From 4b54e46005f8ae28f6f3a6f605d61f0616b779d0 Mon Sep 17 00:00:00 2001 From: Guanbao Yu Date: Wed, 11 Feb 2026 19:00:58 +0800 Subject: [PATCH 01/15] register attn backend to sgl from ATOM enable mla --- atom/config.py | 31 +- atom/model_ops/__init__.py | 1 + atom/model_ops/linear.py | 11 + atom/model_ops/moe.py | 43 +- atom/model_ops/radix_attention.py | 21 +- atom/models/deepseek_v2.py | 809 +++++++- atom/models/qwen3_moe.py | 93 +- atom/plugin/attention_backend/__init__.py | 0 .../attention_backend/sgl_attn_backend.py | 1773 +++++++++++++++++ atom/plugin/register.py | 6 +- atom/utils/envs.py | 1 + 11 files changed, 2749 insertions(+), 40 deletions(-) create mode 100644 atom/plugin/attention_backend/__init__.py create mode 100644 atom/plugin/attention_backend/sgl_attn_backend.py diff --git a/atom/config.py b/atom/config.py index 1d433987d..36312efdc 100644 --- a/atom/config.py +++ b/atom/config.py @@ -817,22 +817,27 @@ def __post_init__(self): # assert os.path.isdir(self.model) assert 1 <= self.tensor_parallel_size <= 8 - self.hf_config = get_hf_config(self.model) + if is_plugin_mode(): + # plugin mode + assert ( + self.plugin_config is not None + ), "plugin_config is required in plugin mode" + self.hf_config = self.plugin_config.model_config.hf_config + else: + self.hf_config = get_hf_config(self.model) + + self.generation_config = get_generation_config(self.model) + if self.generation_config is not None: + if ( + eos_ids := getattr(self.generation_config, "eos_token_id", None) + ) is not None: + self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 - rope_params = getattr(self.hf_config, "rope_scaling", {}) - if rope_params is None: - rope_params = {} - rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None) - rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default") + rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} + rope_params["rope_theta"] = self.hf_config.rope_theta + rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") self.hf_config.rope_parameters = rope_params - - self.generation_config = get_generation_config(self.model) - if self.generation_config is not None: - if ( - eos_ids := getattr(self.generation_config, "eos_token_id", None) - ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids self.quant_config = QuantizationConfig(self.hf_config) hf_config_max_position_embeddings = getattr( self.hf_config, "max_position_embeddings", 8192 diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py index 4b6c0b545..4e0c0c258 100644 --- a/atom/model_ops/__init__.py +++ b/atom/model_ops/__init__.py @@ -5,6 +5,7 @@ # it can be assigned to different attention ops. # By default, PagedAttention is used. # For sglang, RadixAttention will be assigned to Attention +# see register.py for details. 
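As a standalone illustration of the transformers<5 compatibility shim added to Config.__post_init__ above, the following sketch normalizes a legacy rope_scaling dict into rope_parameters. The SimpleNamespace stand-in for hf_config and the normalize_rope_parameters name are illustrative only, not ATOM APIs; the sketch reads rope_type with dict.get(), since rope_scaling is a plain dict.

from types import SimpleNamespace

def normalize_rope_parameters(hf_config):
    """Back-fill hf_config.rope_parameters for transformers < 5."""
    if hasattr(hf_config, "rope_parameters"):
        return hf_config.rope_parameters
    rope_params = getattr(hf_config, "rope_scaling", {}) or {}
    rope_params["rope_theta"] = getattr(hf_config, "rope_theta", None)
    rope_params["rope_type"] = rope_params.get("rope_type", "default")
    hf_config.rope_parameters = rope_params
    return rope_params

# usage: a config object that only carries the legacy fields
cfg = SimpleNamespace(rope_scaling=None, rope_theta=10000.0)
assert normalize_rope_parameters(cfg)["rope_type"] == "default"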
Attention = PagedAttention __all__ = [ diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index ed7459614..7555fcbc3 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -388,10 +388,21 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) + _diag_forward_counter = 0 + @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 ) -> torch.Tensor: + if LinearBase._diag_forward_counter < 5: + LinearBase._diag_forward_counter += 1 + print( + f"[DIAG][LinearBase.forward] prefix={self.prefix} " + f"quant_type={self.quant_type} " + f"w_dtype={self.weight.dtype} " + f"w_scale_shape={tuple(self.weight_scale.shape) if self.weight_scale is not None else None} " + f"x shape={tuple(x.shape)}" + ) if self.quant_type.value == QuantType.No.value: y = tgemm.mm( x, diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 0745df342..4b35b4e5c 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1401,8 +1401,7 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, ) else: - # Direct kernel call for non-EP/DP cases - return rocm_asm_moe_impl( + return torch.ops.aiter.rocm_aiter_fused_moe( x, layer.w13_weight, layer.w2_weight, @@ -1611,7 +1610,20 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: def _process_block_quant(self, layer: nn.Module) -> None: assert self.quant_config["is_dynamic"] + print( + f"[DIAG][Fp8MoE._process_block_quant] BEFORE normalize: " + f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} " + f"w13_scale dtype={layer.w13_weight_scale.dtype} shape={tuple(layer.w13_weight_scale.shape)} " + f"need_normalize={self.need_normalize_e4m3fn_to_e4m3fnuz}" + ) self._normalize_weights_and_scales(layer) + print( + f"[DIAG][Fp8MoE._process_block_quant] AFTER normalize: " + f"w13 dtype={layer.w13_weight.dtype} " + f"w13_scale min={layer.w13_weight_scale.data.min().item():.6f} " + f"max={layer.w13_weight_scale.data.max().item():.6f} " + f"mean={layer.w13_weight_scale.data.float().mean().item():.6f}" + ) if not self.need_normalize_e4m3fn_to_e4m3fnuz: layer.w13_weight = nn.Parameter(layer.w13_weight.data, requires_grad=False) @@ -1624,6 +1636,7 @@ def _process_block_quant(self, layer: nn.Module) -> None: ) shuffle_weights(layer.w13_weight, layer.w2_weight) + print(f"[DIAG][Fp8MoE._process_block_quant] DONE shuffle") def _process_channel_quant(self, layer: nn.Module) -> None: """PTPTC""" @@ -1693,13 +1706,26 @@ def get_fused_moe_quant_config( a2_scale=layer.w2_input_scale, per_act_token_quant=True, ) + elif self.block_quant: + if self.quant_type == QuantType.per_1x128: + block_shape = [128, 128] + elif self.quant_type == QuantType.per_1x32: + block_shape = [1, 32] + else: + block_shape = None + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=block_shape, + ) else: return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - block_shape=None, ) @mark_trace(prefix="fp8_moe", torch_compile=False) @@ -1741,6 +1767,17 @@ def apply( # per_Tensor doesn't support num_local_tokens, so fallback to # rocm_aiter_fused_moe when using per-tensor or no modular kernel. 
if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: + if not hasattr(self, "_diag_apply_printed"): + self._diag_apply_printed = True + print( + f"[DIAG][Fp8MoE.apply] rocm_aiter path: " + f"quant_type={self.quant_type} " + f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} " + f"w13_scale shape={tuple(layer.w13_weight_scale.shape)} " + f"w13_scale min={layer.w13_weight_scale.data.float().min().item():.6f} " + f"max={layer.w13_weight_scale.data.float().max().item():.6f} " + f"x dtype={x.dtype} shape={tuple(x.shape)}" + ) return torch.ops.aiter.rocm_aiter_fused_moe( x, layer.w13_weight, diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 25388b384..c85b0251e 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -9,6 +9,7 @@ from .base_attention import BaseAttention from atom.plugin.prepare import is_plugin_mode, is_sglang from atom.models.utils import maybe_prefix +from atom.utils import envs class RadixAttention(BaseAttention): @@ -47,23 +48,35 @@ def __init__( prefix=prefix, **kwargs, ) - self.rotary_emb = rotary_emb if is_sglang(): from sglang.srt.layers.radix_attention import RadixAttention + _v_head_dim = mla_modules.kv_lora_rank if (use_mla and mla_modules is not None) else head_dim + self.attn = RadixAttention( num_heads=num_heads, head_dim=head_dim, scaling=scale, num_kv_heads=num_kv_heads, layer_id=layer_num, + v_head_dim=_v_head_dim, prefix=maybe_prefix(prefix, "attn"), ) + if self.attn.k_scale is None: + self.attn.k_scale = torch.nn.Parameter( + torch.tensor([1.0], dtype=torch.float32), requires_grad=False + ) + if self.attn.v_scale is None: + self.attn.v_scale = torch.nn.Parameter( + torch.tensor([1.0], dtype=torch.float32), requires_grad=False + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" ) + # if True, save cache will be done in rope + self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM def forward_impl_plugin_mode( self, @@ -82,10 +95,8 @@ def forward_impl_plugin_mode( # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" - if self.rotary_emb is not None: - assert positions is not None, "positions is required for ROPE" - query, key = self.rotary_emb(positions, query, key) - return self.attn(q=query, k=key, v=value, forward_batch=forward_batch) + # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM + return self.attn(query, key, value, forward_batch=forward_batch, save_kv_cache=not self.use_aiter_rope_fused_qknorm) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index d1da9f052..fe84de8ed 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -23,8 +23,9 @@ # limitations under the License. 
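A minimal, self-contained sketch of the scale-defaulting pattern used in the RadixAttention plugin wrapper above (atom/model_ops/radix_attention.py): when the wrapped sglang layer leaves k_scale / v_scale unset, frozen unit-scale parameters are installed so downstream FP8 KV-cache code can always read a tensor. Holder and ensure_unit_kv_scales are illustrative stand-ins, not ATOM or sglang APIs.

import torch

class Holder(torch.nn.Module):
    """Stand-in for the wrapped attention layer (e.g. sglang's RadixAttention)."""
    def __init__(self):
        super().__init__()
        self.k_scale = None
        self.v_scale = None

def ensure_unit_kv_scales(attn: torch.nn.Module) -> None:
    # Install frozen scale parameters of 1.0 wherever the backend left None.
    for name in ("k_scale", "v_scale"):
        if getattr(attn, name, None) is None:
            setattr(attn, name, torch.nn.Parameter(
                torch.tensor([1.0], dtype=torch.float32), requires_grad=False))

attn = Holder()
ensure_unit_kv_scales(attn)
assert float(attn.k_scale) == 1.0 and float(attn.v_scale) == 1.0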
"""Inference-only DeepseekV2/DeepseekV3 model.""" +import json import logging -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Iterable, Any import torch from aiter import ( @@ -38,7 +39,7 @@ top_k_per_row_prefill, ) from aiter.dist.communication_op import tensor_model_parallel_all_reduce -from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size +from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size, get_tp_group from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits from aiter.ops.triton.fused_fp8_quant import ( @@ -70,6 +71,7 @@ RowParallelLinear, use_triton_gemm, ) +from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE, _has_module, quark_post_load_weights from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import ( is_rocm_aiter_fuse_routed_scaling_factor, @@ -89,9 +91,24 @@ from atom.utils.forward_context import get_forward_context from torch import nn from transformers import PretrainedConfig +from atom.plugin.prepare import is_sglang # from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 +from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context +from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.configs.model_config import is_deepseek_nsa +from sglang.srt.models.deepseek_common.utils import _use_aiter_gfx95,_use_aiter,_is_gfx95_supported +from sglang.srt.layers.quantization.rocm_mxfp4_utils import batched_gemm_afp4wfp4_pre_quant +from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + per_tensor_quant_mla_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, +) logger = logging.getLogger("atom") if use_triton_gemm(): @@ -113,6 +130,34 @@ gemm_a8w8_blockscale_preshuffle = None gemm_a16w8_blockscale_preshuffle = None + +from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 + +from sglang.srt.utils.custom_op import register_custom_op + +# TODO(yuwei): remove this wrapper after sgl-kernel registers its own fake/meta impl +# Wrap bmm_fp8 as a custom op so torch.compile does not trace into +# torch.cuda.current_blas_handle() (which returns a non-Tensor). 
+@register_custom_op(mutates_args=["out"]) +def _bmm_fp8_op( + A: torch.Tensor, + B: torch.Tensor, + out: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +) -> None: + _raw_bmm_fp8(A, B, A_scale, B_scale, out.dtype, out) + +def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + _bmm_fp8_op(A, B, out, A_scale, B_scale) + return out + ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION @@ -1431,11 +1476,542 @@ def __init__( self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True - def forward( + # for sglang + self.use_nsa = is_deepseek_nsa(config) + self.use_deep_gemm_bmm = False + self.alt_stream = None + # self.w_kc, self.w_vc = self.kv_b_proj.weight.data.unflatten( + # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) + # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) + self.w_kc, self.w_vc = None, None + self.w_scale = None + self.w_scale_k = None + self.w_scale_v = None + # self.w_kc, self.w_vc = self.kv_b_proj.weight.data.unflatten( + # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) + # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) + + def _forward_sgl_prepare( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None + ) -> torch.Tensor: + # supplementary code, port from forward_common + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + + forward_batch = model_kwargs.get("forward_batch", None) + zero_allocator = model_kwargs.get("zero_allocator", None) + llama_4_scaling = model_kwargs.get("llama_4_scaling", None) + q_lora = None + topk_indices = None + # #region agent log + try: + _pos_0 = int(positions.shape[0]) + _hs_0 = int(hidden_states.shape[0]) if hasattr(hidden_states, "shape") else -1 + _tp = int(get_tensor_model_parallel_world_size()) if forward_batch is not None else -1 + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "A", "location": "deepseek_v2.py:_forward_sgl_prepare_entry", "message": "prepare_entry", "data": {"positions_dim0": _pos_0, "hidden_states_dim0": _hs_0, "tp_world": _tp, "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + if self.q_lora_rank is not None: + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + f"positions={tuple(positions.shape)} hidden_states={tuple(hidden_states.shape)} " + f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " + f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)}" + ) + # qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) + # q, latent_cache = torch.split( + # qkv_lora, + # [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + # dim=-1 + # ) + + q, latent_cache = ( + get_attn_tp_context() + .fetch_qkv_latent() + .split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + ) + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + f"fetched q={tuple(q.shape)} latent={tuple(latent_cache.shape)} " + f"positions={tuple(positions.shape)} 
tp_world={get_tensor_model_parallel_world_size()}" + ) + # #region agent log + try: + _q0, _p0, _tp = int(q.shape[0]), int(positions.shape[0]), int(get_tensor_model_parallel_world_size()) + _fallback_cond = _q0 != _p0 and _tp > 1 + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "B,C", "location": "deepseek_v2.py:after_fetch", "message": "after_fetch", "data": {"q_dim0": _q0, "positions_dim0": _p0, "tp_world": _tp, "fallback_will_run": _fallback_cond}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + + if q.shape[0] != positions.shape[0] and get_tensor_model_parallel_world_size() > 1: + qkv_lora = torch.cat([q, latent_cache], dim=-1) + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] < positions.shape[0]: + raise RuntimeError( + f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, expected {positions.shape[0]}" + ) + qkv_lora = qkv_lora[: positions.shape[0]] + q, latent_cache = torch.split( + qkv_lora, + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + f"after_fallback_gather q={tuple(q.shape)} latent={tuple(latent_cache.shape)}" + ) + # #region agent log + try: + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "C", "location": "deepseek_v2.py:after_fallback", "message": "after_fallback", "data": {"q_dim0": int(q.shape[0]), "positions_dim0": int(positions.shape[0])}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + + k_nope = latent_cache[..., : self.kv_lora_rank] + + # overlap qk norm + if self.alt_stream is not None and get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q = self.q_a_layernorm(q) + with torch.cuda.stream(self.alt_stream): + k_nope = self.kv_a_layernorm(k_nope) + current_stream.wait_stream(self.alt_stream) + else: + # if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8: + # q, _, k_nope, *_ = fused_rms_mxfp4_quant( + # q, + # self.q_a_layernorm.weight, + # self.q_a_layernorm.variance_epsilon, + # k_nope, + # self.kv_a_layernorm.weight, + # self.kv_a_layernorm.variance_epsilon, + # ) + # else: + q_lora = None + _use_aiter_gfx95 = False + if ( + _use_aiter_gfx95 + and + self.q_b_proj.weight.dtype == torch.float8_e4m3fn + ): + if self.use_nsa: + q_quanted, q_lora, k_nope, _ = fused_rms_fp8_group_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=True, + ) + q = q_quanted + else: + q, _, k_nope, _ = fused_rms_fp8_group_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + group_size=128, + dtype_quant=torch.float8_e4m3fn, + res1=None, + output_unquantized_inp1=False, + ) + + else: + q = self.q_a_layernorm(q) + k_nope = self.kv_a_layernorm(k_nope) + + # q_lora needed by indexer + if self.use_nsa: + if q_lora is None: + q_lora = q + + # overlap q_b_proj and indexer during decode + if ( + self.alt_stream is not None + and get_is_capture_mode() + and 
forward_batch.forward_mode.is_decode_or_idle() + and q_lora is not None + ): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + with torch.cuda.stream(self.alt_stream): + k_nope = k_nope.unsqueeze(1) + q = self.q_b_proj(q).view( + -1, self.num_local_heads, self.qk_head_dim + ) + topk_indices = self.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=self.layer_num, + ) + current_stream.wait_stream(self.alt_stream) + else: + k_nope = k_nope.unsqueeze(1) + q = self.q_b_proj(q).view(-1, self.num_local_heads, self.qk_head_dim) + if q_lora is not None: + topk_indices = self.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=self.layer_num, + ) + else: + q = self.q_proj(hidden_states).view( + -1, self.num_local_heads, self.qk_head_dim + ) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + k_nope = latent_cache[..., : self.kv_lora_rank] + k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1) + + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1) + # #region agent log + try: + with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: + _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "D,E", "location": "deepseek_v2.py:before_rope_assert", "message": "before_rope", "data": {"q_pe_dim0": int(q_pe.shape[0]), "k_pe_dim0": int(k_pe.shape[0]), "positions_dim0": int(positions.shape[0]), "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") + except Exception: + pass + # #endregion + print( + f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + f"q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_pe={tuple(k_pe.shape)} " + f"positions={tuple(positions.shape)}" + ) + + _is_hip= True + if self.use_deep_gemm_bmm: + q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1)) + ) + q_nope_out = q_nope.new_empty( + (self.num_local_heads, aligned_m, self.kv_lora_rank) + ) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (q_nope_val, q_nope_scale), + (self.w_kc, self.w_scale_k), + q_nope_out, + masked_m, + expected_m, + ) + q_nope_out = q_nope_out[:, :expected_m, :] + elif _is_hip: + # TODO(haishaw): add bmm_fp8 to ROCm + if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8: + x = q_nope.transpose(0, 1) + q_nope_out = torch.empty( + x.shape[0], + x.shape[1], + self.w_kc.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.w_kc.transpose(-2, -1), + self.w_scale_k.transpose(-2, -1), + torch.bfloat16, + q_nope_out, + ) + else: + if (_use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn) or ( + get_is_capture_mode() and self.w_kc.dtype == torch.float8_e4m3fnuz + ): + # fp8 Triton kernel: always on gfx950, + # cudagraph-only on gfx942 (hides launch overhead) + q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=q_nope, + WQ=self.w_kc.transpose(-1, -2), + w_scale=self.mla_attn.w_scale, + group_size=128, + YQ=None, # allocate (B, M, N) + transpose_bm=False, # (B, M, N) + transpose_bm_in=True, # (M, B, K) + dtype=torch.bfloat16, + ) + + else: + q_nope_out = torch.bmm( + q_nope.to(torch.bfloat16).transpose(0, 1), + self.w_kc.to(torch.bfloat16) * self.w_scale, + ) + + elif self.w_kc.dtype == 
torch.float8_e4m3fn: + # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612 + q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( + q_nope.transpose(0, 1), + ( + torch.zeros((1,), dtype=torch.float32, device=q_nope.device) + # if _is_cublas_ge_129 + # else zero_allocator.allocate(1) + ), + ) + q_nope_out = bmm_fp8( + q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 + ) + else: + q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) + + q_nope_out = q_nope_out.transpose(0, 1) + + if ( + self.rotary_emb is not None + # and (not self._fuse_rope_for_trtllm_mla(forward_batch)) + and (not _use_aiter or not _is_gfx95_supported or self.use_nsa) + ): + # Optional hard check during debugging + assert q_pe.shape[0] == positions.shape[0], ( + f"q_pe tokens {q_pe.shape[0]} != positions {positions.shape[0]}" + ) + assert k_pe.shape[0] == positions.shape[0], ( + f"k_pe tokens {k_pe.shape[0]} != positions {positions.shape[0]}" + ) + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if nsa_use_prefill_cp(forward_batch): + # support allgather+rerrange + k_nope, k_pe = self.rebuild_cp_kv_cache( + latent_cache, forward_batch, k_nope, k_pe + ) + # end forward prepare + return ( + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, + positions, + topk_indices, + llama_4_scaling, + ) + + def _forward_sgl_core( + self, + q_pe, + k_pe, + q_nope_out, + k_nope, + forward_batch, + zero_allocator, + positions, + topk_indices, + llama_4_scaling, + ): + # 1) build q/k for radix attention path + _is_hip = True + q = torch.cat([q_nope_out, q_pe], dim=-1) + k = torch.cat([k_nope, k_pe], dim=-1) + + if llama_4_scaling is not None: + q = q * llama_4_scaling + + # 2) attention core + attn_output = self.mla_attn( + q, + k, + k_nope, + forward_batch=forward_batch, + save_kv_cache=True, + **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), + ) + attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) + + # 3) up-proj by w_vc (port from sglang forward_absorb_core) + if self.use_deep_gemm_bmm: + attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(attn_output.transpose(0, 1)) + ) + attn_bmm_output = attn_output.new_empty( + (self.num_local_heads, aligned_m, self.v_head_dim) + ) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (attn_output_val, attn_output_scale), + (self.w_vc, self.w_scale_v), + attn_bmm_output, + masked_m, + expected_m, + ) + attn_bmm_output = ( + attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2) + ) + + elif _is_hip: + if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8: + x = attn_output.transpose(0, 1) + y = torch.empty( + x.shape[0], + x.shape[1], + self.w_vc.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.w_vc.transpose(-2, -1), + self.w_scale_v.transpose(-2, -1), + torch.bfloat16, + y, + ) + attn_bmm_output = y.transpose(0, 1).flatten(1, 2) + else: + if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn: + y = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=attn_output, + WQ=self.w_vc.transpose(-1, -2), + w_scale=self.w_scale, + group_size=128, + YQ=None, + transpose_bm=False, + transpose_bm_in=True, + dtype=torch.bfloat16, + ) + else: + y = torch.bmm( + attn_output.to(torch.bfloat16).transpose(0, 1), + self.w_vc.to(torch.bfloat16) * self.w_scale, + ) + attn_bmm_output = y.transpose(0, 1).flatten(1, 2) + + 
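+        # The remaining branches are kept from the sglang reference implementation
+        # (CUDA bmm_fp8 / plain bf16 bmm); with _is_hip hard-coded to True above
+        # they are not reached on ROCm.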
elif self.w_vc.dtype == torch.float8_e4m3fn: + attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( + attn_output.transpose(0, 1), + torch.zeros((1,), dtype=torch.float32, device=attn_output.device), + ) + attn_bmm_output = bmm_fp8( + attn_output_val, + self.w_vc, + attn_output_scale, + self.w_scale, + torch.bfloat16, + ) + attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) + + else: + attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) + attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) + + output = self.o_proj(attn_bmm_output) + return output + + def prepare_qkv_latent( + self, + hidden_states: torch.Tensor, + forward_batch, + ): + assert self.q_lora_rank is not None + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + print( + f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] " + f"hidden_states={tuple(hidden_states.shape)} " + f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " + f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " + f"positions={None if getattr(forward_batch, 'positions', None) is None else tuple(forward_batch.positions.shape)}" + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) + print(f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] qkv_lora={tuple(qkv_lora.shape)}") + + # Fallback: when communicator does not enable input_scattered gather, + # force qkv latent token dimension to align with positions. + # Use positions.shape[0] (actual input token count) instead of + # seq_lens_sum (total KV cache length, wrong for decode mode). + expected_tokens = 0 + if hasattr(forward_batch, "positions") and forward_batch.positions is not None: + expected_tokens = int(forward_batch.positions.shape[0]) + if expected_tokens <= 0: + expected_tokens = int(getattr(forward_batch, "seq_lens_sum", 0) or 0) + + if ( + expected_tokens > 0 + and qkv_lora.shape[0] != expected_tokens + and get_tensor_model_parallel_world_size() > 1 + ): + print( + f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] before_fallback_gather " + f"qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens} " + f"tp_world={get_tensor_model_parallel_world_size()}" + ) + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] > expected_tokens: + qkv_lora = qkv_lora[:expected_tokens] + elif qkv_lora.shape[0] < expected_tokens: + raise RuntimeError( + f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " + f"expected {expected_tokens}" + ) + print( + f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] return_qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens}" + ) + return qkv_lora + + + def forward_sgl_plugin_mode( self, positions: torch.Tensor, hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: + # forward_absorb_prepare sglang + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + forward_batch = model_kwargs.get("forward_batch", None) + if forward_batch is None: + raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") + + attn_tp_context = get_attn_tp_context() + print( + f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " + f"positions={tuple(positions.shape)} " + f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " + f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " + f"input_ids={None if getattr(forward_batch, 
'input_ids', None) is None else tuple(forward_batch.input_ids.shape)} " + f"allow_scatter={attn_tp_context.allow_input_scattered}" + ) + with attn_tp_context.maybe_input_scattered(forward_batch): + print( + f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " + f"input_scattered={attn_tp_context.input_scattered}" + ) + if self.q_lora_rank is not None: + attn_tp_context.set_attn_inputs( + AttentionInputs( + hidden_states, + forward_batch, + self.prepare_qkv_latent, + ) + ) + prepared = self._forward_sgl_prepare(positions, hidden_states, **model_kwargs) + return self._forward_sgl_core(*prepared) + + def forward_common( + + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None + ): hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states @@ -1522,7 +2098,200 @@ def forward( k_pe, positions, hidden_states_or_q_c_scale, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None + ) -> torch.Tensor: + if is_sglang(): + attn_output = self.forward_sgl_plugin_mode(positions, hidden_states, **model_kwargs) + else: + attn_output = self.forward_common(positions, hidden_states, **model_kwargs) + return attn_output + + def process_weights_after_loading(self) -> None: + # only for sglang plugin mode + if not is_sglang(): + return + self._process_mla_kv_b_proj_after_loading_sgl() + + def _process_mla_kv_b_proj_after_loading_sgl(self) -> None: + # lazy imports: only needed for sglang plugin path + from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz + from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + inverse_transform_scale_ue8m0, + ) + from sglang.srt.layers.quantization.int8_utils import ( + block_dequant as int8_block_dequant, + ) + from sglang.srt.layers.deep_gemm_wrapper import ( + ENABLE_JIT_DEEPGEMM, + DEEPGEMM_BLACKWELL, ) + from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 + from sglang.srt.models.deepseek_common.utils import ( + _is_cpu, + _is_cpu_amx_available, + _is_cuda, + _is_fp8_fnuz, + _is_hip, + _is_npu, + _use_aiter_gfx95, + awq_dequantize_func, + ) + from sglang.srt.utils import bind_or_assign, get_bool_env_var + + # read kv_b_proj weight (awq compatible) + if hasattr(self.kv_b_proj, "qweight"): + awq_dequantize_f = awq_dequantize_func() + if awq_dequantize_f is None: + raise ValueError("AWQ dequantize function is not supported for current device") + w = awq_dequantize_f( + self.kv_b_proj.qweight, + self.kv_b_proj.scales, + self.kv_b_proj.qzeros, + ).T + else: + w = self.kv_b_proj.weight + + # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes + # into them (weight_loader_process view-casts a detached copy, leaving + # the nn.Parameter as fnuz). At this point LinearBase's + # process_weights_after_loading hasn't run yet (parent module iterates + # before child in named_modules). View-cast back to fn so the + # normalize path works correctly. 
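+        # The view() below only reinterprets the dtype tag; the underlying bytes
+        # are left untouched.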
+ if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: + w = w.view(torch.float8_e4m3fn) + + use_deep_gemm_bmm = False + block_scale = None + weight_block_size = None + + # Derive weight_block_size from ATOM's quant_type system + from aiter import QuantType as _AiterQuantType + _atom_qt = getattr(self.kv_b_proj, "quant_type", None) + if _atom_qt == _AiterQuantType.per_1x128: + weight_block_size = [128, 128] + elif _atom_qt == _AiterQuantType.per_1x32: + weight_block_size = [1, 32] + + # fp8 path + if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + if weight_block_size is not None: + assert hasattr(self.kv_b_proj, "weight_scale_inv") or hasattr(self.kv_b_proj, "weight_scale") + weight_scale = ( + self.kv_b_proj.weight_scale + if hasattr(self.kv_b_proj, "weight_scale") + else self.kv_b_proj.weight_scale_inv + ) + + if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fn: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=weight_scale, + input_scale=None, + ) + else: + weight = w + + if ( + should_deepgemm_weight_requant_ue8m0( + weight_block_size=weight_block_size + ) + and getattr(weight_scale, "format_ue8m0", False) + ): + weight_scale = inverse_transform_scale_ue8m0(weight_scale, mn=weight.shape[-2]) + + if _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128: + if ( + ENABLE_JIT_DEEPGEMM + and not DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, + weight_scale, + weight_block_size, + torch.bfloat16, + ) + else: + w, scale = block_quant_to_tensor_quant(weight, weight_scale, weight_block_size) + self.w_scale = scale + else: + if w.dtype == torch.float8_e4m3fn and _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self.kv_b_proj.weight_scale, + input_scale=None, + ) + else: + weight = w + weight_scale = self.kv_b_proj.weight_scale + + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self.w_scale = scale + + # int8 path + if w.dtype == torch.int8: + if weight_block_size is not None: + assert hasattr(self.kv_b_proj, "weight_scale_inv") + w = int8_block_dequant( + w, + self.kv_b_proj.weight_scale_inv, + weight_block_size, + ).to(torch.bfloat16) + else: + w = w.to(torch.bfloat16) * self.kv_b_proj.weight_scale.to(torch.bfloat16) + + # split to kc/vc + w_kc, w_vc = w.unflatten( + 0, (-1, self.qk_nope_head_dim + self.v_head_dim) + ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) + + # quark fp4 special path (ATOM utility) + quant_method = getattr(self.kv_b_proj, "quant_method", None) + quant_config = getattr(quant_method, "quant_config", None) + if _use_aiter_gfx95 and quant_config is not None and quant_config.get_name() == "quark": + w_kc, self.w_scale_k, w_vc, self.w_scale_v = quark_post_load_weights(self, w, "mxfp4") + + if not use_deep_gemm_bmm: + self.w_kc = bind_or_assign( + self.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) + ) + w_vc = w_vc.contiguous().transpose(1, 2) + if _is_npu: + w_vc = w_vc.contiguous() + self.w_vc = bind_or_assign(self.w_vc, w_vc) + + if hasattr(self.kv_b_proj, "weight_scale") and self.w_scale is None: + self.w_scale = bind_or_assign(self.w_scale, self.kv_b_proj.weight_scale) + if _is_hip: + self.w_scale *= 2.0 + + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + self.w_kc = self.w_kc.to(torch.bfloat16) * self.w_scale + self.w_vc = self.w_vc.to(torch.bfloat16) * 
self.w_scale + else: + num_tiles_k = self.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = self.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten(0, (-1, (num_tiles_k + num_tiles_n))).split( + [num_tiles_k, num_tiles_n], dim=1 + ) + + self.w_scale_k = bind_or_assign(self.w_scale_k, ws_kc.transpose(1, 2).contiguous()) + self.w_scale_v = bind_or_assign(self.w_scale_v, ws_vc.contiguous()) + self.w_kc = bind_or_assign(self.w_kc, w_kc.transpose(1, 2).contiguous()) + self.w_vc = bind_or_assign(self.w_vc, w_vc.contiguous()) + self.use_deep_gemm_bmm = True class DeepseekV2DecoderLayer(nn.Module): @@ -1627,14 +2396,20 @@ def __init__( self.fuse_rmsnorm_quant = ( ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None ) - def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], + **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: # Self Attention + print( + f"[MLA_DBG][decoder_layer][layer={self.layer_idx}] positions={tuple(positions.shape)} " + f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " + f"residual={None if residual is None else tuple(residual.shape)} " + f"fuse_input_norm_quant={self.fuse_input_norm_quant}" + ) if self.fuse_input_norm_quant: assert self.quant_dtype is not None weight = self.input_layernorm.weight @@ -1689,6 +2464,7 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + **model_kwargs, ) if hidden_states.dtype == torch.float16: @@ -1788,6 +2564,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + **model_kwargs: dict[str, Any] | None ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -1801,7 +2578,7 @@ def forward( residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer : self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer(positions, hidden_states, residual, **model_kwargs) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -1840,6 +2617,7 @@ def __init__( quant_config = atom_config.quant_config self.config = config self.quant_config = quant_config + self.atom_config = atom_config if hasattr(config, "q_lora_rank") and config.q_lora_rank is not None: self.packed_modules_mapping = { @@ -1873,6 +2651,10 @@ def __init__( self.model.make_empty_intermediate_tensors ) + if is_sglang(): + from sglang.srt.configs.model_config import is_deepseek_nsa + get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -1882,9 +2664,11 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + **model_kwargs: dict[str, Any] | None ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds + input_ids, positions, intermediate_tensors, inputs_embeds, + **model_kwargs, ) return hidden_states @@ -1912,6 +2696,19 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + 
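+        """Ignore the incoming weights iterator and reload through ATOM's
+        plugin-mode loader; returns the names of the weights that were loaded."""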
# load weights in plugin mode and discard passed weights generator + # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + + # lazy import to avoid circular import issue since model_loader also imports model.. + from atom.model_loader.loader import load_model_in_plugin_mode + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 4d68bd42e..71eec8b17 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,11 +1,11 @@ -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Iterable import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size # from atom.model_ops.rotary_embedding import get_rope -from aiter.rotary_embedding import get_rope +from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg from atom.config import Config, QuantizationConfig from atom.model_ops.activation import SiluAndMul @@ -33,6 +33,8 @@ ) from atom.utils import envs from torch import nn +from atom.model_loader.loader import load_model_in_plugin_mode +from atom.plugin.prepare import is_sglang # import torch.distributed as dist from transformers import PretrainedConfig @@ -41,7 +43,7 @@ ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION ) - +ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM class Qwen3MoeMLP(nn.Module): def __init__( @@ -229,6 +231,61 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.layer_num = layer_num + self.k_scale = torch.tensor([1.0], dtype=torch.float32) + self.v_scale = torch.tensor([1.0], dtype=torch.float32) + + def forward_sgl_plugin_mode( + self, + positions: torch.Tensor, + qkv: torch.Tensor, + **model_kwargs: dict[str, Any] | None, + ): + if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: + forward_batch = model_kwargs.get("forward_batch", None) + assert forward_batch is not None, "forward_batch is required for sglang" + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_num) + block_size = 1024 # Default fallback + if hasattr(forward_batch, 'attn_backend') and hasattr(forward_batch.attn_backend, 'page_size'): + block_size = forward_batch.attn_backend.page_size + elif hasattr(forward_batch.token_to_kv_pool, 'allocator') and hasattr(forward_batch.token_to_kv_pool.allocator, 'page_size'): + block_size = forward_batch.token_to_kv_pool.allocator.page_size + elif hasattr(forward_batch.token_to_kv_pool, 'page_size'): + block_size = forward_batch.token_to_kv_pool.page_size + x = 16 // k_buffer.element_size() + aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( + kv_cache = (k_buffer, v_buffer), + cache_loc = forward_batch.out_cache_loc, + k_scale = self.k_scale, + v_scale = self.v_scale, + return_kv = True, + use_shuffle_layout = True, + block_size = block_size, + x = x, + ) + q, k, v = self.rotary_emb( + qkv, + self.q_norm.weight, + self.k_norm.weight, + positions, + self.num_heads, + self.num_kv_heads, + self.q_norm.eps, + fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, + ) + 
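+        # Non-fused fallback: split qkv, apply q/k RMSNorm and rope separately;
+        # the KV cache is written later by the attention backend
+        # (RadixAttention passes save_kv_cache=True when the fused rope path is off).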
else: + q, k, v = torch.split( + qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn( + q, k, v, positions=positions, **model_kwargs + ) + return attn_output def forward( self, @@ -246,13 +303,16 @@ def forward( query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv ) else: - # Add qk-norm - q = self.q_norm(q) - k = self.k_norm(k) + if is_sglang(): + attn_output = self.forward_sgl_plugin_mode(positions, qkv, **model_kwargs) + else: + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) - attn_output = self.attn( - query=q, key=k, value=v, positions=positions, **model_kwargs - ) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, **model_kwargs + ) output = self.o_proj(attn_output) return output @@ -266,7 +326,7 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> No self.hidden_size = config.hidden_size rope_params = config.rope_parameters rope_theta = rope_params["rope_theta"] - rope_scaling = rope_params + rope_scaling = None if rope_params["rope_type"] == "default" else rope_params kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix @@ -348,7 +408,7 @@ def forward( @support_torch_compile -class Qwen3MoeModel(nn.Module): +class Qwen3MoeModel(torch.nn.Module): def __init__( self, atom_config: Config, @@ -522,3 +582,14 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." 
+ ) + return loaded_weights_record \ No newline at end of file diff --git a/atom/plugin/attention_backend/__init__.py b/atom/plugin/attention_backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py new file mode 100644 index 000000000..2ececc348 --- /dev/null +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -0,0 +1,1773 @@ +from __future__ import annotations + +""" +end to end attention solution with aiter kernels +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +import sglang.srt.layers.attention.aiter_backend as _sglang_aiter +from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import get_bool_env_var + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInput + +try: + from aiter import ( + flash_attn_varlen_func, + dtypes, + get_pa_metadata_info_v1, + get_pa_metadata_v1, + pa_fwd_asm, + pa_persistent_fwd, + mla_decode_fwd, + ) +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." + ) + +# MLA prefill kernels - imported separately to avoid breaking the main aiter imports +mla_prefill_ps_asm_fwd = None +mla_reduce_v1 = None +mla_prefill_fwd = None +try: + from aiter import mla_prefill_ps_asm_fwd +except ImportError: + pass +try: + from aiter import mla_reduce_v1 +except ImportError: + pass +try: + from aiter.mla import mla_prefill_fwd + from aiter.mla import mla_decode_fwd +except ImportError: + pass + +import triton +import triton.language as tl + + +@triton.jit +def reshape_and_cache_shuffle_kernel( + key_ptr, # [num_tokens, num_kv_heads, head_size] + value_ptr, # [num_tokens, num_kv_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] + value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] + slot_mapping_ptr, # [num_tokens] + k_scale_ptr, + v_scale_ptr, + x, + k_stride0, + v_stride0, + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE: tl.constexpr, + QUANT: tl.constexpr, +): + tid = tl.program_id(0) + head_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + src_offset_k = tid * k_stride0 + head_id * head_size + src_offset_v = tid * v_stride0 + head_id * head_size + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + dst_offset = ( + block_id * num_kv_heads * head_size * block_size + + head_id * head_size * block_size + ) + dst_k_shuffle_offset = ( + dst_offset + offset // x * block_size * x + block_offset * x + offset % x + ) + dst_v_shuffle_offset = ( + dst_offset + + block_offset // x * head_size * x + + offset * x + + block_offset % x + ) + k_val = tl.load(key_ptr + src_offset_k + offset) + v_val = tl.load(value_ptr + src_offset_v + offset) + if QUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = key_cache_ptr.type.element_ty + v_dtype = value_cache_ptr.type.element_ty + k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) + v_val = (v_val.to(tl.float32) / 
v_scale).to(v_dtype) + tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) + tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + +def reshape_and_cache_shuffle_triton( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scales: torch.Tensor, + v_scales: torch.Tensor, +): + num_tokens = slot_mapping.shape[0] + _, num_kv_heads, head_size = key.shape + num_blocks, block_size, _, _ = key_cache.shape + x = 16 // key_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=key_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=value_cache.dtype, + device="meta", + ) + new_key_cache = key_cache.view_as(k_cache_template) + new_value_cache = value_cache.view_as(v_cache_template) + QUANT = False + if kv_cache_dtype.startswith("fp8"): + QUANT = True + grid = ( + num_tokens, + num_kv_heads, + ) + reshape_and_cache_shuffle_kernel[grid]( + key, + value, + new_key_cache, + new_value_cache, + slot_mapping, + k_scales, + v_scales, + x, + key.stride(0), + value.stride(0), + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE=head_size, + QUANT=QUANT, + ) + +@dataclass +class ForwardMetadata: + # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode + kv_indptr: Optional[torch.Tensor] + kv_indices: Optional[torch.Tensor] + qo_indptr: Optional[torch.Tensor] + kv_last_page_len: Optional[torch.Tensor] + max_q_len: Optional[int] + max_kv_len: Optional[int] + page_table: Optional[torch.Tensor] + kv_lens: Optional[torch.Tensor] + # mla + work_metadata: Optional[torch.Tensor] = None + work_info_set: Optional[torch.Tensor] = None + work_indptr: Optional[torch.Tensor] = None + reduce_indptr: Optional[torch.Tensor] = None + reduce_final_map: Optional[torch.Tensor] = None + reduce_partial_map: Optional[torch.Tensor] = None + fp8_prefill_kv_indices: Optional[torch.Tensor] = None + num_kv_splits: Optional[int] = None + # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) + pa_metadata_qo_indptr: Optional[torch.Tensor] = None + pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None + pa_metadata_kv_indices: Optional[torch.Tensor] = None + pa_metadata_context_lens: Optional[torch.Tensor] = None + pa_metadata_max_qlen: Optional[int] = None + pa_metadata_tp_q_head_num: Optional[int] = None + # Prefill metadata for mha_batch_prefill_func (only used in prefill mode, non-MLA) + # prefill_pages_kv_indptr: Optional[torch.Tensor] = None + # prefill_kv_indices: Optional[torch.Tensor] = None + # prefill_kv_last_page_lens: Optional[torch.Tensor] = None + + + +class ATOMAttnBackendForSgl(AiterAttnBackend): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + super().__init__(model_runner, skip_prefill, kv_indptr_buf) + mapping = getattr( + model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None + ) + + if isinstance(mapping, dict) and mapping: + first_full_attn_id = next(iter(mapping.keys())) + else: + first_full_attn_id = 0 + + self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building + + # assert not self.use_mla, "MLA mode is not implemented yet in ATOMAttnBackendForSgl." 
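+        # The buffers below are sized for the worst case (max batch size /
+        # max context length) so they can be reused across steps and captured
+        # by CUDA graphs without reallocation.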
+ + # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] + # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] + max_bs = model_runner.req_to_token_pool.size + self.pa_decode_qo_indptr = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + self.seq_lens = torch.zeros( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=model_runner.device + ) + # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=model_runner.device + ) + + if not self.use_mla: + # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes) + max_num_blocks_per_seq = (self.max_context_len + self.page_size - 1) // self.page_size + max_total_blocks = max_bs * max_num_blocks_per_seq + self.pa_kv_indices = torch.zeros( + max_total_blocks, dtype=torch.int32, device=self.device + ) + # Pre-allocate pa_kv_indptr buffer (similar to self.kv_indptr, but dedicated for pa_persistent_fwd) + self.pa_kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=self.device + ) + # Pre-initialized batch indices [0, 1, 2, ..., max_bs-1] for Triton kernel + self.pa_batch_indices = torch.arange( + 0, max_bs, dtype=torch.int32, device=self.device + ) + + # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0) + + + self.logits_soft_cap = 0.0 + + self.forward_metadata: ForwardMetadata = None + + self.pa_metadata_buffers = None + + k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id) + num_slots, num_kv_heads, _ = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + max_total_tokens = num_blocks * block_size + self.k_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.v_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.decode_using_pa_ps = self.page_size == 1024 + + + + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables for triton attention backend.""" + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + qo_indptr = None + kv_last_page_len = None + max_q_len = None + page_table = None + + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 + + if self.use_mla: + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0) + kv_last_page_len = self.kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + ( + 
work_metadata, + work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs) + + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + work_metadata, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, # max_kv_len + None, # page_table + None, # kv_lens + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + + else: + if self.decode_using_pa_ps: + # Non-MLA decode mode: use same logic as CUDA Graph mode for page_table construction + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + + # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph) + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True) + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + page_table = self.req_to_token[forward_batch.req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] + page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + else: + page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + None, # qo_indptr not used in non-MLA mode + None, # kv_last_page_len not used in non-MLA mode + 1, # max_q_len = 1 for decode mode + None, + page_table_persistent[:bs, :max_seq_pages] if self.decode_using_pa_ps else page_table, + seq_lens_persistent[:bs] if self.decode_using_pa_ps else forward_batch.seq_lens, + ) + + # Build pa_metadata for pa_persistent_fwd + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + # return # Early return for non-MLA decode mode + else: + prefix_lens = forward_batch.extend_prefix_lens + + if self.use_mla: + # raise NotImplementedError("MLA prefill mode is not implemented yet in ATOMAttnBackendForSgl.") + self.mla_indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + forward_batch.extend_seq_lens, + forward_batch.extend_seq_lens.max().item(), + forward_batch.seq_lens.max().item(), + spec_info=None + ) + + max_q_len = self.mla_indices_updater_prefill.max_q_len + qo_indptr = self.mla_indices_updater_prefill.qo_indptr + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + fp8_prefill_kv_indices = None + reduce_partial_map = None + + from sglang.srt.utils import is_gfx95_supported + _use_fp8_prefill_attn = ( + get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported() + ) + if _use_fp8_prefill_attn: + tile_q = 256 + qlen_granularity = tile_q // (self.num_head // self.num_kv_head) + ( + work_metadata, + 
work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map + ) = self.make_mla_prefill_ps_meta_data_buffer( + bs, max_q_len, qlen_granularity + ) + + + self.make_mla_prefill_ps_meta_data( + qo_indptr, + qo_indptr, + forward_batch.seq_lens, + work_metadata, + work_indptr, + work_info_set, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + is_causal=True, + ) + + total_s = int(forward_batch.extend_seq_lens.sum()) + fp8_prefill_kv_indices = torch.arange( + total_s, device=self.device, dtype=torch.int32 + ) + + self.forward_metadata = ForwardMetadata( + self.mla_indices_updater_prefill.kv_indptr, + self.mla_indices_updater_prefill.kv_indices, + qo_indptr, + self.kv_last_page_len[:bs], + max_q_len, + self.mla_indices_updater_prefill.max_kv_len, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + fp8_prefill_kv_indices=fp8_prefill_kv_indices, + ) + else: + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + # Get page_table for mha_batch_prefill_func + page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.qo_indptr[: bs + 1], # qo_indptr is set by indices_updater_prefill + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + None, + forward_batch.seq_lens, + ) + + if (forward_batch.forward_mode.is_extend() and + not self.use_mla and + self.forward_metadata.page_table is not None): + if self.page_size > 1: + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[:, self.strided_indices[:max_seq_pages]] // self.page_size + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_prefill(forward_batch.batch_size) + if not self.decode_using_pa_ps and self.page_size > 1 and self.forward_metadata.page_table is not None: + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[:, self.strided_indices] // self.page_size + ) + + def _allocate_pa_metadata_buffers( + self, + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ): + """Allocate or reuse pa_metadata buffers.""" + if self.pa_metadata_buffers is None: + self.pa_metadata_buffers = {} + + def _get_size_val(size): + return size[0] if isinstance(size, tuple) else size + + # Allocate work_metadata_ptrs + size_val = _get_size_val(work_metadata_ptrs_size) + if ("work_metadata_ptrs" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val): + self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( + work_metadata_ptrs_size, dtype=work_metadata_ptrs_type, device=self.device + ) + + # Allocate work_indptr + size_val = _get_size_val(work_indptr_size) + if ("work_indptr" not in 
self.pa_metadata_buffers or + self.pa_metadata_buffers["work_indptr"].shape[0] < size_val): + self.pa_metadata_buffers["work_indptr"] = torch.zeros( + work_indptr_size, dtype=work_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_indptr"].zero_() + + # Allocate work_info + size_val = _get_size_val(work_info_size) + if ("work_info" not in self.pa_metadata_buffers or + len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) or + self.pa_metadata_buffers["work_info"].shape[0] < size_val): + self.pa_metadata_buffers["work_info"] = torch.zeros( + work_info_size, dtype=work_info_type, device=self.device + ) + else: + self.pa_metadata_buffers["work_info"].zero_() + + # Allocate reduce_indptr + size_val = _get_size_val(reduce_indptr_size) + if ("reduce_indptr" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val): + self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( + reduce_indptr_size, dtype=reduce_indptr_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_indptr"].zero_() + + # Allocate reduce_final_map + size_val = _get_size_val(reduce_final_map_size) + if ("reduce_final_map" not in self.pa_metadata_buffers or + len(self.pa_metadata_buffers["reduce_final_map"].shape) < len(reduce_final_map_size) or + self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val): + self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( + reduce_final_map_size, dtype=reduce_final_map_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_final_map"].zero_() + + # Allocate reduce_partial_map + reduce_partial_map_size_val = reduce_partial_map_size if isinstance(reduce_partial_map_size, int) else reduce_partial_map_size[0] + if ("reduce_partial_map" not in self.pa_metadata_buffers or + self.pa_metadata_buffers["reduce_partial_map"].shape[0] < reduce_partial_map_size_val): + self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( + reduce_partial_map_size, dtype=reduce_partial_map_type, device=self.device + ) + else: + self.pa_metadata_buffers["reduce_partial_map"].zero_() + + def _build_pa_metadata_for_decode( + self, + batch_size: int, + tp_q_head_num: Optional[int] = None, + ): + """Build pa_metadata buffers for pa_persistent_fwd in decode mode. + + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. + The metadata can be reused across multiple layers in the same forward pass. + + Args: + batch_size: Batch size for the current forward pass + tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. 
+ """ + max_qlen = 1 + + # Use provided tp_q_head_num or default to self.num_head + if tp_q_head_num is None: + tp_q_head_num = self.num_head + + # kv_dtype_for_metadata = dtypes.fp8 + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + batch_size, + self.num_kv_head, + ) + # Allocate metadata buffers with reuse optimization for multi-layer forward passes + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] + + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) + # Note: kv_lens comes from self.seq_lens which is already int32 + context_lens = self.forward_metadata.kv_lens + + kernel_block_size = self.page_size + num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size + # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync + # page_table shape: [batch_size, max_num_blocks_per_seq] + # Note: page_table comes from self.page_table which is already int32 and always set before this call + page_table = self.forward_metadata.page_table + + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], # [0, 1, 2, ..., batch_size-1] + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr + kv_indices = self.pa_kv_indices + + get_pa_metadata_v1( + seqlens_qo_indptr=qo_indptr, + pages_kv_indptr=pages_kv_indptr, + context_lens=context_lens.int(), + num_heads_per_head_k=tp_q_head_num // self.num_kv_head, + num_heads_k=self.num_kv_head, + is_causal=True, + work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"], + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + kv_granularity=max(kernel_block_size, 16), + block_size=kernel_block_size, + max_seqlen_qo=max_qlen, + uni_seqlen_qo=max_qlen, + fast_mode=True, + topk=-1, + max_split_per_batch=-1, + ) + # Store computed values in ForwardMetadata for reuse in forward_decode + self.forward_metadata.pa_metadata_qo_indptr = qo_indptr + self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr + self.forward_metadata.pa_metadata_kv_indices = kv_indices + self.forward_metadata.pa_metadata_context_lens = context_lens + self.forward_metadata.pa_metadata_max_qlen = max_qlen + self.forward_metadata.pa_metadata_tp_q_head_num = 
tp_q_head_num + + def _build_pa_metadata_for_prefill(self, batch_size: int): + """Build metadata for mha_batch_prefill_func in prefill mode. + + This method prepares page-level metadata needed for mha_batch_prefill_func. + The metadata is computed once per forward pass and reused across all layers. + """ + block_size = self.page_size + context_lens = self.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + # Page-level kv_indptr (reuse pa_kv_indptr buffer) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Build kv_indices from page_table using triton kernel + page_table = self.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # kv_indices = self.pa_kv_indices + + # Compute kv_last_page_lens for each sequence + # kv_last_page_lens = ((context_lens - 1) % block_size + 1).int() + + # Store in ForwardMetadata for reuse in forward_extend + # self.forward_metadata.prefill_pages_kv_indptr = pages_kv_indptr + # self.forward_metadata.prefill_kv_indices = kv_indices + # self.forward_metadata.prefill_kv_last_page_lens = kv_last_page_lens + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + # Always use preshuffle layout for pa_fwd_asm + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=self.device + ) + self.seq_lens = torch.zeros( + (max_bs,), dtype=torch.int32, device=self.device + ) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ) + + if self.use_mla and _sglang_aiter._use_mla_ps_kernel: + max_seqlen_qo = 1 + ( + self.work_metadata, + self.work_indptr, + self.work_info_set, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, max_bs) + elif self.use_mla: + self.work_metadata = None + self.work_indptr = None + self.work_info_set = None + self.reduce_indptr = None + self.reduce_final_map = None + self.reduce_partial_map = None + + if self.decode_using_pa_ps and not self.use_mla: + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + max_bs, + self.num_kv_head, + ) + + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], 
+ forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + if forward_mode.is_decode_or_idle(): + if self.use_mla: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + else: + page_table = self.page_table[:bs, :] + self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens[:bs] + self.forward_metadata = ForwardMetadata( + None, + None, + None, + None, + 1, + None, + page_table, + seq_lens_persistent, + ) + + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + return + else: + raise ValueError(f"Invalid mode: {forward_mode=}") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + ): + if forward_mode.is_decode_or_idle(): + if self.use_mla: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + 
qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + else: + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + page_table = self.req_to_token[req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] + page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + + self.forward_metadata = ForwardMetadata( + None, + None, + None, + None, + 1, + None, + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + else: + raise ValueError("Invalid forward mode") + + def set_kv_buffer_with_layout_shuffle( + self, + cache_loc, + k, + v, + k_buffer, + v_buffer, + k_scale, + v_scale, + block_size, + ): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + v_buffer = v_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + "auto", + k_scale, + v_scale, + ) + + def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + self.logits_soft_cap = layer.logit_cap + + if k is not None: + assert v is not None + if save_kv_cache: + if self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + cache_loc, k, v, k_buffer, v_buffer, + layer.k_scale, layer.v_scale, self.page_size, + ) + + if self.use_mla: + return self._forward_extend_mla(q, k, v, layer, forward_batch) + else: + return self._forward_extend_mha(q, k, v, layer, forward_batch) + + def _forward_extend_mha(self, q, k, v, layer, forward_batch): + """Non-MLA extend path: standard MHA with flash_attn_varlen_func.""" + seqlens_in_batch = forward_batch.seq_lens + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + if q.dtype != k.dtype and k.dtype == dtypes.fp8: + q = q.to(dtypes.fp8) + o = 
flash_attn_varlen_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=self.forward_metadata.max_q_len, + max_seqlen_k=self.forward_metadata.max_kv_len, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1, 0), + sink_ptr=None, + ) + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def _forward_extend_mla(self, q, k, v, layer, forward_batch): + """MLA extend path: ported from sglang aiter_backend forward_extend MLA logic.""" + max_q_len = self.forward_metadata.max_q_len + max_kv_len = self.forward_metadata.max_kv_len + kv_indptr = self.forward_metadata.kv_indptr + kv_indices = self.forward_metadata.kv_indices + qo_indptr = self.forward_metadata.qo_indptr + + K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + kv_lora_rank = V_Buffer.shape[-1] + qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank + qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim + + assert len(q.shape) == 3 + assert len(k.shape) == 3 + assert len(v.shape) == 3 + + if ( + forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ): + return self._forward_extend_mla_normal( + q, k, v, layer, forward_batch, + K_Buffer, V_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ) + elif forward_batch.forward_mode.is_target_verify(): + return self._forward_extend_mla_target_verify( + q, layer, K_Buffer, qo_indptr, + ) + elif forward_batch.forward_mode.is_draft_extend(): + return self._forward_extend_mla_draft_extend( + q, layer, K_Buffer, qo_indptr, + ) + else: + raise ValueError( + f"Invalid forward mode for MLA extend: {forward_batch.forward_mode=}" + ) + + def _forward_extend_mla_normal( + self, q, k, v, layer, forward_batch, + K_Buffer, V_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ): + """Normal MLA extend (not target_verify, not draft_extend). 
+ + Three sub-paths mirroring sglang aiter_backend: + 1) No prefix -> fp8 prefill kernel (mla_prefill_ps_asm_fwd) or flash_attn fallback + 2) Has prefix, absorbed weights differ -> decompress via kv_b_proj + flash_attn + 3) Has prefix, qk_head_dim matches -> mla_prefill_fwd kernel + """ + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + + if kv_indices.shape[0] == 0 or extend_no_prefix: + # --- Sub-path 1: no prefix, pure prefill --- + use_fp8_prefill = ( + self.forward_metadata.fp8_prefill_kv_indices is not None + ) + if use_fp8_prefill: + total_s = q.shape[0] + nhead = layer.tp_q_head_num + v_head_dim = layer.v_head_dim + + if q.dtype != dtypes.fp8: + q = q.to(dtypes.fp8) + if k.dtype != dtypes.fp8: + k = k.to(dtypes.fp8) + if v.dtype != dtypes.fp8: + v = v.to(dtypes.fp8) + one_scale = torch.ones( + (), dtype=torch.float32, device=q.device + ) + + kv_indptr_asm = qo_indptr + kv_indices_asm = self.forward_metadata.fp8_prefill_kv_indices + + tile_q = 256 + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + + logits = torch.empty( + (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), + dtype=torch.float32, + device=q.device, + ) + attn_lse = torch.empty( + (reduce_partial_map.size(0) * tile_q, nhead), + dtype=torch.float32, + device=q.device, + ) + final_lse = torch.empty( + (total_s, nhead), + dtype=torch.float32, + device=q.device, + ) + output = q.new_empty( + (total_s, nhead, v_head_dim), + dtype=self.input_dtype, + ) + + mla_prefill_ps_asm_fwd( + q, + k, + v, + qo_indptr, + kv_indptr_asm, + kv_indices_asm, + self.forward_metadata.work_indptr, + self.forward_metadata.work_info_set, + max_q_len, + layer.scaling, + True, + logits, + attn_lse, + output, + one_scale, + one_scale, + one_scale, + ) + mla_reduce_v1( + logits, + attn_lse, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + tile_q, + output, + final_lse, + ) + elif layer.qk_head_dim == (kv_lora_rank + qk_rope_head_dim) and mla_prefill_fwd is not None: + # Absorbed MLA: head_dim (576) exceeds CK limit (256), + # use mla_prefill_fwd which natively supports large MLA head dims. + # For no-prefix, use input k (bfloat16) directly instead of K_Buffer + # (which may be FP8). mla_prefill_fwd doesn't support FP8 KV. 
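+                # Illustrative shapes only (assuming DeepSeek-V3-style dims,
+                # kv_lora_rank=512, qk_rope_head_dim=64): layer.qk_head_dim = 512 + 64
+                # = 576, so q is viewed as [total_s, tp_q_head_num, 576], the compressed
+                # KV as [total_s, 1, 1, 576], and the kernel writes
+                # [total_s, tp_q_head_num, layer.v_head_dim].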
+ if layer.qk_head_dim != layer.v_head_dim: + output = q.new_empty( + (q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + ) + else: + output = torch.empty_like(q) + total_s = q.shape[0] + temp_kv_indices = torch.arange( + total_s, device=q.device, dtype=torch.int32 + ) + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.view(-1, 1, 1, layer.qk_head_dim), + output.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, + qo_indptr, + temp_kv_indices, + self.forward_metadata.kv_last_page_len, + max_q_len, + layer.scaling, + layer.logit_cap, + ) + else: + output = flash_attn_varlen_func( + q, + k, + v, + qo_indptr, + qo_indptr, + max_q_len, + max_q_len, + softmax_scale=layer.scaling, + causal=True, + ) + return output + + elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim): + # --- Sub-path 2: has prefix, need kv_b_proj decompress --- + K_Buffer = torch.index_select(K_Buffer, 0, kv_indices) + kvc, k_pe = torch.split( + K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1 + ) + + if self.kv_cache_dtype == dtypes.fp8: + dtype = q.dtype + kvc = kvc.to(dtype) + k_pe = k_pe.to(dtype) + + kvprefix = layer.kv_b_proj(kvc.contiguous())[0] + kvprefix = kvprefix.view( + -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim + ) + k_prefix, v_prefix = torch.split( + kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 + ) + k_prefix = torch.cat( + [ + k_prefix, + torch.broadcast_to( + k_pe, + (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]), + ), + ], + dim=-1, + ) + + assert ( + forward_batch.extend_prefix_lens.shape + == forward_batch.extend_seq_lens.shape + ) + + o = flash_attn_varlen_func( + q, + k_prefix, + v_prefix, + qo_indptr, + kv_indptr, + max_q_len, + max_kv_len, + softmax_scale=layer.scaling, + causal=True, + ) + return o + + else: + # --- Sub-path 3: has prefix, qk_head_dim == kv_lora_rank + qk_rope_head_dim --- + # Gather needed KV entries and cast to bf16 (K_Buffer may be FP8) + k_selected = torch.index_select(K_Buffer, 0, kv_indices) + if k_selected.dtype != q.dtype: + k_selected = k_selected.to(q.dtype) + compact_kv_indices = torch.arange( + k_selected.shape[0], device=q.device, dtype=torch.int32 + ) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + ) + else: + o = torch.empty_like(q) + + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_selected.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, + kv_indptr, + compact_kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + layer.scaling, + layer.logit_cap, + ) + return o + + def _forward_extend_mla_target_verify(self, q, layer, K_Buffer, qo_indptr): + """MLA target_verify path (speculative decoding verification).""" + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + + mla_decode_fwd( + q, + K_Buffer.view(-1, 1, 1, layer.qk_head_dim), + o, + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + 
self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + return o + + def _forward_extend_mla_draft_extend(self, q, layer, K_Buffer, qo_indptr): + """MLA draft_extend path (speculative decoding draft extension).""" + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + + mla_decode_fwd( + q, + K_Buffer.view(-1, 1, 1, layer.qk_head_dim), + o, + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + return o + + + def forward_decode_pa( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + + if self.use_mla: + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + + num_kv_splits = self.forward_metadata.num_kv_splits + + mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_buffer.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, 
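+            # NOTE: the same per-layer layer.k_scale is passed for both q_scale and
+            # kv_scale below; it is the dequant scale for the FP8-quantized compressed
+            # KV cache (assumption: q was quantized with the same scale).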
+ kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + block_size = self.page_size + num_slots, num_kv_heads, head_size = k_buffer.shape + num_blocks = num_slots // block_size + k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + + x = 16 // k_buffer.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + pa_fwd_asm( + Q=q, + K=new_key_cache, + V=new_value_cache, + block_tables=self.forward_metadata.page_table, + context_lens=self.forward_metadata.kv_lens, + block_tables_stride0=self.forward_metadata.page_table.stride(0), + K_QScale=self.k_scale, + V_QScale=self.v_scale, + out_=o, + ) + return o + + def forward_decode_pa_ps( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + # Create o as 3D tensor [batch_size, num_heads, head_dim] for both MLA and pa_fwd_asm + # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) + # Use q.shape[0] instead of forward_batch.batch_size to be safe + batch_size = q.shape[0] + head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim + o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) + + if save_kv_cache: + if self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + + if self.use_mla: + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + + num_kv_splits = self.forward_metadata.num_kv_splits + + mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_buffer.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + 
intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + num_slots, num_kv_heads, head_size = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + + + quant_dtype = dtypes.fp8 + x = 16 // quant_dtype.itemsize + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=k_buffer.dtype, + device="meta", + ) + # V: [num_blocks, block_size, num_kv_heads, head_size] -> [num_blocks, num_kv_heads, block_size // x, head_size, x] + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=v_buffer.dtype, + device="meta", + ) + new_key_cache = k_buffer.view_as(k_cache_template) + new_value_cache = v_buffer.view_as(v_cache_template) + + total_tokens = num_blocks * block_size + k_qscale = self.k_qscale[:, :total_tokens] + v_qscale = self.v_qscale[:, :total_tokens] + + q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) + + + assert self.forward_metadata.pa_metadata_qo_indptr is not None, "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_pages_kv_indptr is not None, "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_kv_indices is not None, "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_context_lens is not None, "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" + assert self.forward_metadata.pa_metadata_max_qlen is not None, "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + + qo_indptr = self.forward_metadata.pa_metadata_qo_indptr + kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr + kv_indices = self.forward_metadata.pa_metadata_kv_indices + context_lens = self.forward_metadata.pa_metadata_context_lens + max_qlen = self.forward_metadata.pa_metadata_max_qlen + + + _, _ = pa_persistent_fwd( + Q=q, + K=new_key_cache, + V=new_value_cache, + output=o, + max_qlen=max_qlen, + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + kv_indices=kv_indices, + context_lens=context_lens, + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + K_QScale=k_qscale, + V_QScale=v_qscale, + softmax_scale=layer.scaling, + mask=1, + ) + return o.view(-1, layer.tp_q_head_num * head_dim_out) + + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + if self.use_mla: + return self._forward_decode_mla(q, k, v, layer, forward_batch, save_kv_cache) + else: + if self.decode_using_pa_ps: + return self.forward_decode_pa_ps(q, k, v, layer, forward_batch, save_kv_cache) + else: + return self.forward_decode_pa(q, k, v, layer, forward_batch, save_kv_cache) + + def _forward_decode_mla(self, q, k, v, layer, forward_batch, save_kv_cache): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) 
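+        # With the compressed (absorbed) MLA cache each token stores a single latent
+        # vector of width layer.qk_head_dim (kv_lora_rank + qk_rope_head_dim), so
+        # k_buffer below is viewed as [-1, 1, 1, layer.qk_head_dim] and all
+        # tp_q_head_num query heads attend to that shared per-token latent.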
+ + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num * layer.v_head_dim), + dtype=self.input_dtype, + ) + else: + o = torch.empty_like(q, dtype=self.input_dtype) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + work_metadata = self.forward_metadata.work_metadata + work_indptr = self.forward_metadata.work_indptr + work_info_set = self.forward_metadata.work_info_set + reduce_indptr = self.forward_metadata.reduce_indptr + reduce_final_map = self.forward_metadata.reduce_final_map + reduce_partial_map = self.forward_metadata.reduce_partial_map + num_kv_splits = self.forward_metadata.num_kv_splits + + if layer.layer_id == 0: + _q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) + _k_view = k_buffer.view(-1, 1, 1, layer.qk_head_dim) + _o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) + print( + f"[MLA_DECODE_DBG] layer=0" + f" q={tuple(_q_view.shape)} q.dtype={_q_view.dtype}" + f" k_buf={tuple(_k_view.shape)} k_buf.dtype={_k_view.dtype}" + f" o={tuple(_o_view.shape)} o.dtype={_o_view.dtype}" + f" qo_indptr={self.forward_metadata.qo_indptr.tolist()}" + f" kv_indptr={self.forward_metadata.kv_indptr.tolist()}" + f" kv_indices_len={self.forward_metadata.kv_indices.shape[0]}" + f" kv_indices_max={self.forward_metadata.kv_indices.max().item()}" + f" kv_last_page_len={self.forward_metadata.kv_last_page_len.tolist()}" + f" max_q_len={self.forward_metadata.max_q_len}" + f" sm_scale={layer.scaling}" + f" logit_cap={layer.logit_cap}" + f" k_scale={layer.k_scale}" + f" num_kv_splits={num_kv_splits}" + f" page_size={self.page_size}" + f" work_metadata={tuple(work_metadata.shape) if work_metadata is not None else None}" + f" work_indptr={tuple(work_indptr.shape) if work_indptr is not None else None}" + f" work_info_set={tuple(work_info_set.shape) if work_info_set is not None else None}" + f" reduce_indptr={tuple(reduce_indptr.shape) if reduce_indptr is not None else None} val={reduce_indptr.tolist() if reduce_indptr is not None and reduce_indptr.numel() < 20 else 'big'}" + f" reduce_final_map={tuple(reduce_final_map.shape) if reduce_final_map is not None else None}" + f" reduce_partial_map={tuple(reduce_partial_map.shape) if reduce_partial_map is not None else None}" + f" intra_batch_mode={_sglang_aiter.intra_batch_mode}" + f" _use_mla_ps_kernel={_sglang_aiter._use_mla_ps_kernel}" + f" fast_mode={_sglang_aiter.fast_mode}" + , flush=True, + ) + + mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_buffer.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + sm_scale=layer.scaling, + logit_cap=layer.logit_cap, + work_meta_data=work_metadata, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=layer.k_scale, + kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=num_kv_splits, + ) + + return o diff --git a/atom/plugin/register.py b/atom/plugin/register.py index af2427fbf..55ee75bfb 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -3,6 +3,7 @@ from atom.models.qwen3 import 
Qwen3ForCausalLM from atom.models.qwen3_moe import Qwen3MoeForCausalLM from atom.models.glm4_moe import Glm4MoeForCausalLM +from atom.models.deepseek_v2 import DeepseekV3ForCausalLM from atom.config import Config from atom.plugin.prepare import is_vllm, is_sglang @@ -12,6 +13,7 @@ "Qwen3ForCausalLM": Qwen3ForCausalLM, "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, "Glm4MoeForCausalLM": Glm4MoeForCausalLM, + "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, } @@ -28,9 +30,9 @@ def _register_custom_attention_to_sglang() -> None: @register_attention_backend("aiter") def create_atom_backend(runner): - from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl - return AiterAttnBackend(runner) + return ATOMAttnBackendForSgl(runner) def register_ops_to_sglang(atom_config: Config) -> None: diff --git a/atom/utils/envs.py b/atom/utils/envs.py index ea543b5f7..ae3345146 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -75,6 +75,7 @@ "ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "0" ).lower() == "1", + "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1", } From d2014329352f243a9c66d44e31e8a86f182032e3 Mon Sep 17 00:00:00 2001 From: Guanbao Yu Date: Wed, 11 Feb 2026 19:03:24 +0800 Subject: [PATCH 02/15] make format happy --- atom/config.py | 4 +- atom/model_ops/radix_attention.py | 8 +- atom/models/qwen3_moe.py | 39 ++- .../attention_backend/sgl_attn_backend.py | 309 ++++++++++++------ 4 files changed, 237 insertions(+), 123 deletions(-) diff --git a/atom/config.py b/atom/config.py index 36312efdc..62d77209d 100644 --- a/atom/config.py +++ b/atom/config.py @@ -831,7 +831,9 @@ def __post_init__(self): if ( eos_ids := getattr(self.generation_config, "eos_token_id", None) ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids + self.stop_token_ids = ( + [eos_ids] if isinstance(eos_ids, int) else eos_ids + ) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index c85b0251e..4340311b0 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -96,7 +96,13 @@ def forward_impl_plugin_mode( forward_batch = kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM - return self.attn(query, key, value, forward_batch=forward_batch, save_kv_cache=not self.use_aiter_rope_fused_qknorm) + return self.attn( + query, + key, + value, + forward_batch=forward_batch, + save_kv_cache=not self.use_aiter_rope_fused_qknorm, + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 71eec8b17..279845595 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -45,6 +45,7 @@ ) ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM + class Qwen3MoeMLP(nn.Module): def __init__( self, @@ -243,24 +244,30 @@ def forward_sgl_plugin_mode( if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: forward_batch = model_kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" - k_buffer, v_buffer = 
forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_num) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + self.layer_num + ) block_size = 1024 # Default fallback - if hasattr(forward_batch, 'attn_backend') and hasattr(forward_batch.attn_backend, 'page_size'): + if hasattr(forward_batch, "attn_backend") and hasattr( + forward_batch.attn_backend, "page_size" + ): block_size = forward_batch.attn_backend.page_size - elif hasattr(forward_batch.token_to_kv_pool, 'allocator') and hasattr(forward_batch.token_to_kv_pool.allocator, 'page_size'): + elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( + forward_batch.token_to_kv_pool.allocator, "page_size" + ): block_size = forward_batch.token_to_kv_pool.allocator.page_size - elif hasattr(forward_batch.token_to_kv_pool, 'page_size'): + elif hasattr(forward_batch.token_to_kv_pool, "page_size"): block_size = forward_batch.token_to_kv_pool.page_size x = 16 // k_buffer.element_size() aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( - kv_cache = (k_buffer, v_buffer), - cache_loc = forward_batch.out_cache_loc, - k_scale = self.k_scale, - v_scale = self.v_scale, - return_kv = True, - use_shuffle_layout = True, - block_size = block_size, - x = x, + kv_cache=(k_buffer, v_buffer), + cache_loc=forward_batch.out_cache_loc, + k_scale=self.k_scale, + v_scale=self.v_scale, + return_kv=True, + use_shuffle_layout=True, + block_size=block_size, + x=x, ) q, k, v = self.rotary_emb( qkv, @@ -282,9 +289,7 @@ def forward_sgl_plugin_mode( q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn( - q, k, v, positions=positions, **model_kwargs - ) + attn_output = self.attn(q, k, v, positions=positions, **model_kwargs) return attn_output def forward( @@ -304,7 +309,9 @@ def forward( ) else: if is_sglang(): - attn_output = self.forward_sgl_plugin_mode(positions, qkv, **model_kwargs) + attn_output = self.forward_sgl_plugin_mode( + positions, qkv, **model_kwargs + ) else: # Add qk-norm q = self.q_norm(q) diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 2ececc348..7f2216511 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -93,10 +93,7 @@ def reshape_and_cache_shuffle_kernel( dst_offset + offset // x * block_size * x + block_offset * x + offset % x ) dst_v_shuffle_offset = ( - dst_offset - + block_offset // x * head_size * x - + offset * x - + block_offset % x + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x ) k_val = tl.load(key_ptr + src_offset_k + offset) v_val = tl.load(value_ptr + src_offset_v + offset) @@ -110,6 +107,7 @@ def reshape_and_cache_shuffle_kernel( tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + def reshape_and_cache_shuffle_triton( key: torch.Tensor, value: torch.Tensor, @@ -161,6 +159,7 @@ def reshape_and_cache_shuffle_triton( QUANT=QUANT, ) + @dataclass class ForwardMetadata: # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode @@ -214,7 +213,13 @@ def __init__( self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building +<<<<<<< HEAD # assert not self.use_mla, "MLA mode is not implemented yet in ATOMAttnBackendForSgl." +======= + assert ( + not self.use_mla + ), "MLA mode is not implemented yet in ATOMAttnBackendForSgl." 
+>>>>>>> bfc8900 (make format happy) # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] @@ -226,7 +231,9 @@ def __init__( (max_bs,), dtype=torch.int32, device=model_runner.device ) self.page_table = torch.zeros( - (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=model_runner.device + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=model_runner.device, ) # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes) self.strided_indices = torch.arange( @@ -235,7 +242,9 @@ def __init__( if not self.use_mla: # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes) - max_num_blocks_per_seq = (self.max_context_len + self.page_size - 1) // self.page_size + max_num_blocks_per_seq = ( + self.max_context_len + self.page_size - 1 + ) // self.page_size max_total_blocks = max_bs * max_num_blocks_per_seq self.pa_kv_indices = torch.zeros( max_total_blocks, dtype=torch.int32, device=self.device @@ -251,13 +260,12 @@ def __init__( # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0) - self.logits_soft_cap = 0.0 self.forward_metadata: ForwardMetadata = None - + self.pa_metadata_buffers = None - + k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id) num_slots, num_kv_heads, _ = k_buffer.shape block_size = self.page_size @@ -271,9 +279,6 @@ def __init__( ) self.decode_using_pa_ps = self.page_size == 1024 - - - def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" bs = forward_batch.batch_size @@ -370,34 +375,53 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens_cpu = forward_batch.seq_lens_cpu if seq_lens_cpu is None: seq_lens_cpu = forward_batch.seq_lens.cpu() - + # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph) page_table_persistent = self.page_table seq_lens_persistent = self.seq_lens seq_lens_persistent.fill_(0) page_table_persistent.fill_(0) - seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True) - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 - page_table = self.req_to_token[forward_batch.req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] - page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + seq_lens_persistent[:bs].copy_( + forward_batch.seq_lens, non_blocking=True + ) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + forward_batch.req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) else: - page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] - + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( kv_indptr, kv_indices, None, # qo_indptr not used in non-MLA mode None, # kv_last_page_len not used in non-MLA mode - 1, # max_q_len = 1 for decode mode + 1, # max_q_len = 1 for decode mode None, - page_table_persistent[:bs, :max_seq_pages] if self.decode_using_pa_ps else page_table, - 
seq_lens_persistent[:bs] if self.decode_using_pa_ps else forward_batch.seq_lens, + ( + page_table_persistent[:bs, :max_seq_pages] + if self.decode_using_pa_ps + else page_table + ), + ( + seq_lens_persistent[:bs] + if self.decode_using_pa_ps + else forward_batch.seq_lens + ), ) - + # Build pa_metadata for pa_persistent_fwd if self.decode_using_pa_ps: self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) - # return # Early return for non-MLA decode mode + # return # Early return for non-MLA decode mode else: prefix_lens = forward_batch.extend_prefix_lens @@ -488,11 +512,15 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): spec_info=None, ) # Get page_table for mha_batch_prefill_func - page_table = forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :] + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] self.forward_metadata = ForwardMetadata( self.indices_updater_prefill.kv_indptr, self.indices_updater_prefill.kv_indices, - self.qo_indptr[: bs + 1], # qo_indptr is set by indices_updater_prefill + self.qo_indptr[ + : bs + 1 + ], # qo_indptr is set by indices_updater_prefill None, self.indices_updater_prefill.max_q_len, self.indices_updater_prefill.max_kv_len, @@ -500,22 +528,34 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.seq_lens, ) - if (forward_batch.forward_mode.is_extend() and - not self.use_mla and - self.forward_metadata.page_table is not None): + if ( + forward_batch.forward_mode.is_extend() + and not self.use_mla + and self.forward_metadata.page_table is not None + ): if self.page_size > 1: seq_lens_cpu = forward_batch.seq_lens_cpu if seq_lens_cpu is None: seq_lens_cpu = forward_batch.seq_lens.cpu() - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 self.forward_metadata.page_table = ( - self.forward_metadata.page_table[:, self.strided_indices[:max_seq_pages]] // self.page_size + self.forward_metadata.page_table[ + :, self.strided_indices[:max_seq_pages] + ] + // self.page_size ) if self.decode_using_pa_ps: self._build_pa_metadata_for_prefill(forward_batch.batch_size) - if not self.decode_using_pa_ps and self.page_size > 1 and self.forward_metadata.page_table is not None: + if ( + not self.decode_using_pa_ps + and self.page_size > 1 + and self.forward_metadata.page_table is not None + ): self.forward_metadata.page_table = ( - self.forward_metadata.page_table[:, self.strided_indices] // self.page_size + self.forward_metadata.page_table[:, self.strided_indices] + // self.page_size ) def _allocate_pa_metadata_buffers( @@ -536,90 +576,112 @@ def _allocate_pa_metadata_buffers( """Allocate or reuse pa_metadata buffers.""" if self.pa_metadata_buffers is None: self.pa_metadata_buffers = {} - + def _get_size_val(size): return size[0] if isinstance(size, tuple) else size - + # Allocate work_metadata_ptrs size_val = _get_size_val(work_metadata_ptrs_size) - if ("work_metadata_ptrs" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val): + if ( + "work_metadata_ptrs" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val + ): self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( - work_metadata_ptrs_size, dtype=work_metadata_ptrs_type, device=self.device + work_metadata_ptrs_size, + dtype=work_metadata_ptrs_type, + 
device=self.device, ) - + # Allocate work_indptr size_val = _get_size_val(work_indptr_size) - if ("work_indptr" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["work_indptr"].shape[0] < size_val): + if ( + "work_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["work_indptr"].shape[0] < size_val + ): self.pa_metadata_buffers["work_indptr"] = torch.zeros( work_indptr_size, dtype=work_indptr_type, device=self.device ) else: self.pa_metadata_buffers["work_indptr"].zero_() - + # Allocate work_info size_val = _get_size_val(work_info_size) - if ("work_info" not in self.pa_metadata_buffers or - len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) or - self.pa_metadata_buffers["work_info"].shape[0] < size_val): + if ( + "work_info" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) + or self.pa_metadata_buffers["work_info"].shape[0] < size_val + ): self.pa_metadata_buffers["work_info"] = torch.zeros( work_info_size, dtype=work_info_type, device=self.device ) else: self.pa_metadata_buffers["work_info"].zero_() - + # Allocate reduce_indptr size_val = _get_size_val(reduce_indptr_size) - if ("reduce_indptr" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val): + if ( + "reduce_indptr" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val + ): self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( reduce_indptr_size, dtype=reduce_indptr_type, device=self.device ) else: self.pa_metadata_buffers["reduce_indptr"].zero_() - + # Allocate reduce_final_map size_val = _get_size_val(reduce_final_map_size) - if ("reduce_final_map" not in self.pa_metadata_buffers or - len(self.pa_metadata_buffers["reduce_final_map"].shape) < len(reduce_final_map_size) or - self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val): + if ( + "reduce_final_map" not in self.pa_metadata_buffers + or len(self.pa_metadata_buffers["reduce_final_map"].shape) + < len(reduce_final_map_size) + or self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val + ): self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( reduce_final_map_size, dtype=reduce_final_map_type, device=self.device ) else: self.pa_metadata_buffers["reduce_final_map"].zero_() - + # Allocate reduce_partial_map - reduce_partial_map_size_val = reduce_partial_map_size if isinstance(reduce_partial_map_size, int) else reduce_partial_map_size[0] - if ("reduce_partial_map" not in self.pa_metadata_buffers or - self.pa_metadata_buffers["reduce_partial_map"].shape[0] < reduce_partial_map_size_val): + reduce_partial_map_size_val = ( + reduce_partial_map_size + if isinstance(reduce_partial_map_size, int) + else reduce_partial_map_size[0] + ) + if ( + "reduce_partial_map" not in self.pa_metadata_buffers + or self.pa_metadata_buffers["reduce_partial_map"].shape[0] + < reduce_partial_map_size_val + ): self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( - reduce_partial_map_size, dtype=reduce_partial_map_type, device=self.device + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=self.device, ) else: self.pa_metadata_buffers["reduce_partial_map"].zero_() def _build_pa_metadata_for_decode( - self, - batch_size: int, + self, + batch_size: int, tp_q_head_num: Optional[int] = None, ): """Build pa_metadata buffers for pa_persistent_fwd in decode mode. - + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. 
The metadata can be reused across multiple layers in the same forward pass. - + Args: batch_size: Batch size for the current forward pass tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. """ max_qlen = 1 - + # Use provided tp_q_head_num or default to self.num_head if tp_q_head_num is None: tp_q_head_num = self.num_head - + # kv_dtype_for_metadata = dtypes.fp8 ( (work_metadata_ptrs_size, work_metadata_ptrs_type), @@ -648,22 +710,22 @@ def _build_pa_metadata_for_decode( reduce_partial_map_type, ) qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] - + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) # Note: kv_lens comes from self.seq_lens which is already int32 context_lens = self.forward_metadata.kv_lens - + kernel_block_size = self.page_size num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync # page_table shape: [batch_size, max_num_blocks_per_seq] # Note: page_table comes from self.page_table which is already int32 and always set before this call page_table = self.forward_metadata.page_table - + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) create_flashinfer_kv_indices_triton[(batch_size,)]( page_table, @@ -719,7 +781,7 @@ def _build_pa_metadata_for_prefill(self, batch_size: int): # Page-level kv_indptr (reuse pa_kv_indptr buffer) pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - + # Build kv_indices from page_table using triton kernel page_table = self.forward_metadata.page_table create_flashinfer_kv_indices_triton[(batch_size,)]( @@ -759,11 +821,11 @@ def init_cuda_graph_state( # Always use preshuffle layout for pa_fwd_asm self.page_table = torch.zeros( - (max_bs, self.max_context_len // self.page_size), dtype=torch.int32, device=self.device - ) - self.seq_lens = torch.zeros( - (max_bs,), dtype=torch.int32, device=self.device + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=self.device, ) + self.seq_lens = torch.zeros((max_bs,), dtype=torch.int32, device=self.device) self.strided_indices = torch.arange( 0, self.max_context_len, self.page_size, device=self.device ) @@ -1045,8 +1107,12 @@ def set_kv_buffer_with_layout_shuffle( num_slots, num_kv_heads, head_dim = k_buffer.shape num_blocks = num_slots // block_size num_slots_with_block = num_blocks * block_size - k_buffer = k_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) - v_buffer = v_buffer[:num_slots_with_block].view(num_blocks, block_size, num_kv_heads, head_dim) + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) reshape_and_cache_shuffle_triton( k, v, @@ -1466,7 +1532,16 @@ def forward_decode_pa( k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) - self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) + self.set_kv_buffer_with_layout_shuffle( + 
forward_batch.out_cache_loc, + k, + v, + k_buffer, + v_buffer, + layer.k_scale, + layer.v_scale, + self.page_size, + ) if self.use_mla: k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) @@ -1505,13 +1580,19 @@ def forward_decode_pa( ) else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) block_size = self.page_size num_slots, num_kv_heads, head_size = k_buffer.shape num_blocks = num_slots // block_size - k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) x = 16 // k_buffer.element_size() k_cache_template = torch.empty( @@ -1555,7 +1636,11 @@ def forward_decode_pa_ps( # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) # Use q.shape[0] instead of forward_batch.batch_size to be safe batch_size = q.shape[0] - head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim + head_dim_out = ( + layer.v_head_dim + if layer.qk_head_dim != layer.v_head_dim + else layer.head_dim + ) o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) if save_kv_cache: @@ -1605,13 +1690,18 @@ def forward_decode_pa_ps( num_kv_splits=num_kv_splits, ) else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) num_slots, num_kv_heads, head_size = k_buffer.shape block_size = self.page_size num_blocks = num_slots // block_size - k_buffer = k_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - v_buffer = v_buffer[:num_blocks * block_size].view(num_blocks, block_size, num_kv_heads, head_size) - + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) quant_dtype = dtypes.fp8 x = 16 // quant_dtype.itemsize @@ -1628,27 +1718,35 @@ def forward_decode_pa_ps( ) new_key_cache = k_buffer.view_as(k_cache_template) new_value_cache = v_buffer.view_as(v_cache_template) - + total_tokens = num_blocks * block_size k_qscale = self.k_qscale[:, :total_tokens] v_qscale = self.v_qscale[:, :total_tokens] - + q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) - - assert self.forward_metadata.pa_metadata_qo_indptr is not None, "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_pages_kv_indptr is not None, "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_kv_indices is not None, "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_context_lens is not None, "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" - assert self.forward_metadata.pa_metadata_max_qlen is not None, "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" - + assert ( + self.forward_metadata.pa_metadata_qo_indptr is not None + ), "pa_metadata_qo_indptr 
should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_pages_kv_indptr is not None + ), "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_kv_indices is not None + ), "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_context_lens is not None + ), "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" + assert ( + self.forward_metadata.pa_metadata_max_qlen is not None + ), "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" + qo_indptr = self.forward_metadata.pa_metadata_qo_indptr kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr kv_indices = self.forward_metadata.pa_metadata_kv_indices context_lens = self.forward_metadata.pa_metadata_context_lens max_qlen = self.forward_metadata.pa_metadata_max_qlen - - + _, _ = pa_persistent_fwd( Q=q, K=new_key_cache, @@ -1667,25 +1765,26 @@ def forward_decode_pa_ps( K_QScale=k_qscale, V_QScale=v_qscale, softmax_scale=layer.scaling, - mask=1, + mask=1, ) return o.view(-1, layer.tp_q_head_num * head_dim_out) - def forward_decode( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): if self.use_mla: return self._forward_decode_mla(q, k, v, layer, forward_batch, save_kv_cache) else: if self.decode_using_pa_ps: - return self.forward_decode_pa_ps(q, k, v, layer, forward_batch, save_kv_cache) + return self.forward_decode_pa_ps( + q, k, v, layer, forward_batch, save_kv_cache + ) else: return self.forward_decode_pa(q, k, v, layer, forward_batch, save_kv_cache) From 215aab85b19afe7caf3073619077c1912cf8c799 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 9 Mar 2026 16:13:10 +0000 Subject: [PATCH 03/15] arg parse for format launch --- atom/plugin/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 299b905c5..f811b64b2 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -127,10 +127,13 @@ def _generate_atom_config_from_sglang_config(config: Any): from sglang.srt.configs.load_config import LoadConfig from atom.config import Config, ParallelConfig, CompilationConfig + # Format1: sglang serve --model-path ... + # Format2: python3 -m sglang.launch_server --model-path ... 
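+    # Descriptive note on the slicing below: for Format1, sys.argv looks like
+    # ["sglang", "serve", "--model-path", ...], so the real flags start at argv[2];
+    # for Format2 it is ["launch_server.py", "--model-path", ...], so flags start at argv[1].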
+ args_list = sys.argv[2:] if sys.argv[1] == "serve" else sys.argv[1:] # sglang has no global config variable like vllm, # so here construct the server args from sys.argv passed by users # this is the only way to get full arguments - server_args: ServerArgs = prepare_server_args(sys.argv[1:]) + server_args: ServerArgs = prepare_server_args(args_list) sgl_model_config = SglangModelConfig.from_server_args(server_args) sgl_model_opt_config = ModelOptConfig( @@ -222,7 +225,6 @@ def generate_atom_config_for_plugin_mode(config: Any = None): """ logger.info("Generate atom config for plugin mode from passed config") - atom_config = None from atom.plugin import is_vllm, is_sglang from atom.config import set_current_atom_config From 96af643f3560b4c71f56a596cea0d9bb9b2ad4c9 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 19 Mar 2026 14:46:53 +0000 Subject: [PATCH 04/15] remove print logging --- atom/models/deepseek_v2.py | 135 ++++++------------ .../attention_backend/sgl_attn_backend.py | 60 ++++---- 2 files changed, 67 insertions(+), 128 deletions(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index fe84de8ed..98f2f649d 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1507,30 +1507,7 @@ def _forward_sgl_prepare( llama_4_scaling = model_kwargs.get("llama_4_scaling", None) q_lora = None topk_indices = None - # #region agent log - try: - _pos_0 = int(positions.shape[0]) - _hs_0 = int(hidden_states.shape[0]) if hasattr(hidden_states, "shape") else -1 - _tp = int(get_tensor_model_parallel_world_size()) if forward_batch is not None else -1 - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "A", "location": "deepseek_v2.py:_forward_sgl_prepare_entry", "message": "prepare_entry", "data": {"positions_dim0": _pos_0, "hidden_states_dim0": _hs_0, "tp_world": _tp, "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") - except Exception: - pass - # #endregion if self.q_lora_rank is not None: - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"positions={tuple(positions.shape)} hidden_states={tuple(hidden_states.shape)} " - f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " - f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)}" - ) - # qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) - # q, latent_cache = torch.split( - # qkv_lora, - # [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - # dim=-1 - # ) - q, latent_cache = ( get_attn_tp_context() .fetch_qkv_latent() @@ -1539,20 +1516,6 @@ def _forward_sgl_prepare( dim=-1, ) ) - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"fetched q={tuple(q.shape)} latent={tuple(latent_cache.shape)} " - f"positions={tuple(positions.shape)} tp_world={get_tensor_model_parallel_world_size()}" - ) - # #region agent log - try: - _q0, _p0, _tp = int(q.shape[0]), int(positions.shape[0]), int(get_tensor_model_parallel_world_size()) - _fallback_cond = _q0 != _p0 and _tp > 1 - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "B,C", "location": "deepseek_v2.py:after_fetch", "message": "after_fetch", "data": {"q_dim0": _q0, "positions_dim0": _p0, "tp_world": _tp, "fallback_will_run": _fallback_cond}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") - 
except Exception: - pass - # #endregion if q.shape[0] != positions.shape[0] and get_tensor_model_parallel_world_size() > 1: qkv_lora = torch.cat([q, latent_cache], dim=-1) @@ -1567,17 +1530,6 @@ def _forward_sgl_prepare( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"after_fallback_gather q={tuple(q.shape)} latent={tuple(latent_cache.shape)}" - ) - # #region agent log - try: - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "C", "location": "deepseek_v2.py:after_fallback", "message": "after_fallback", "data": {"q_dim0": int(q.shape[0]), "positions_dim0": int(positions.shape[0])}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") - except Exception: - pass - # #endregion k_nope = latent_cache[..., : self.kv_lora_rank] @@ -1687,18 +1639,11 @@ def _forward_sgl_prepare( q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1) - # #region agent log - try: - with open("/mnt/raid0/zhiyan/plugin_gb/.cursor/debug-17e017.log", "a") as _f: - _f.write(json.dumps({"sessionId": "17e017", "hypothesisId": "D,E", "location": "deepseek_v2.py:before_rope_assert", "message": "before_rope", "data": {"q_pe_dim0": int(q_pe.shape[0]), "k_pe_dim0": int(k_pe.shape[0]), "positions_dim0": int(positions.shape[0]), "q_lora_rank": getattr(self, "q_lora_rank", None)}, "timestamp": __import__("time").time_ns() // 1000000}) + "\n") - except Exception: - pass - # #endregion - print( - f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - f"q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_pe={tuple(k_pe.shape)} " - f"positions={tuple(positions.shape)}" - ) + # print( + # f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " + # f"q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_pe={tuple(k_pe.shape)} " + # f"positions={tuple(positions.shape)}" + # ) _is_hip= True if self.use_deep_gemm_bmm: @@ -1743,7 +1688,7 @@ def _forward_sgl_prepare( q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( X=q_nope, WQ=self.w_kc.transpose(-1, -2), - w_scale=self.mla_attn.w_scale, + w_scale=self.w_scale, group_size=128, YQ=None, # allocate (B, M, N) transpose_bm=False, # (B, M, N) @@ -1924,15 +1869,15 @@ def prepare_qkv_latent( hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states - print( - f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] " - f"hidden_states={tuple(hidden_states.shape)} " - f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " - f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " - f"positions={None if getattr(forward_batch, 'positions', None) is None else tuple(forward_batch.positions.shape)}" - ) + # print( + # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] " + # f"hidden_states={tuple(hidden_states.shape)} " + # f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " + # f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " + # f"positions={None if getattr(forward_batch, 'positions', None) is None else tuple(forward_batch.positions.shape)}" + # ) qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) - print(f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] qkv_lora={tuple(qkv_lora.shape)}") + # 
print(f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] qkv_lora={tuple(qkv_lora.shape)}") # Fallback: when communicator does not enable input_scattered gather, # force qkv latent token dimension to align with positions. @@ -1949,11 +1894,11 @@ def prepare_qkv_latent( and qkv_lora.shape[0] != expected_tokens and get_tensor_model_parallel_world_size() > 1 ): - print( - f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] before_fallback_gather " - f"qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens} " - f"tp_world={get_tensor_model_parallel_world_size()}" - ) + # print( + # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] before_fallback_gather " + # f"qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens} " + # f"tp_world={get_tensor_model_parallel_world_size()}" + # ) qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) if qkv_lora.shape[0] > expected_tokens: qkv_lora = qkv_lora[:expected_tokens] @@ -1962,9 +1907,9 @@ def prepare_qkv_latent( f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " f"expected {expected_tokens}" ) - print( - f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] return_qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens}" - ) + # print( + # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] return_qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens}" + # ) return qkv_lora @@ -1981,19 +1926,19 @@ def forward_sgl_plugin_mode( raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") attn_tp_context = get_attn_tp_context() - print( - f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " - f"positions={tuple(positions.shape)} " - f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " - f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " - f"input_ids={None if getattr(forward_batch, 'input_ids', None) is None else tuple(forward_batch.input_ids.shape)} " - f"allow_scatter={attn_tp_context.allow_input_scattered}" - ) + # print( + # f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " + # f"positions={tuple(positions.shape)} " + # f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " + # f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " + # f"input_ids={None if getattr(forward_batch, 'input_ids', None) is None else tuple(forward_batch.input_ids.shape)} " + # f"allow_scatter={attn_tp_context.allow_input_scattered}" + # ) with attn_tp_context.maybe_input_scattered(forward_batch): - print( - f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " - f"input_scattered={attn_tp_context.input_scattered}" - ) + # print( + # f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " + # f"input_scattered={attn_tp_context.input_scattered}" + # ) if self.q_lora_rank is not None: attn_tp_context.set_attn_inputs( AttentionInputs( @@ -2404,12 +2349,12 @@ def forward( **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: # Self Attention - print( - f"[MLA_DBG][decoder_layer][layer={self.layer_idx}] positions={tuple(positions.shape)} " - f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " - f"residual={None if residual is None else tuple(residual.shape)} " - f"fuse_input_norm_quant={self.fuse_input_norm_quant}" - ) + # print( + # f"[MLA_DBG][decoder_layer][layer={self.layer_idx}] positions={tuple(positions.shape)} " + # f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " + # 
f"residual={None if residual is None else tuple(residual.shape)} " + # f"fuse_input_norm_quant={self.fuse_input_norm_quant}" + # ) if self.fuse_input_norm_quant: assert self.quant_dtype is not None weight = self.input_layernorm.weight diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 7f2216511..fe935fe8e 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -213,13 +213,7 @@ def __init__( self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building -<<<<<<< HEAD # assert not self.use_mla, "MLA mode is not implemented yet in ATOMAttnBackendForSgl." -======= - assert ( - not self.use_mla - ), "MLA mode is not implemented yet in ATOMAttnBackendForSgl." ->>>>>>> bfc8900 (make format happy) # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] @@ -1818,33 +1812,33 @@ def _forward_decode_mla(self, q, k, v, layer, forward_batch, save_kv_cache): _q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) _k_view = k_buffer.view(-1, 1, 1, layer.qk_head_dim) _o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) - print( - f"[MLA_DECODE_DBG] layer=0" - f" q={tuple(_q_view.shape)} q.dtype={_q_view.dtype}" - f" k_buf={tuple(_k_view.shape)} k_buf.dtype={_k_view.dtype}" - f" o={tuple(_o_view.shape)} o.dtype={_o_view.dtype}" - f" qo_indptr={self.forward_metadata.qo_indptr.tolist()}" - f" kv_indptr={self.forward_metadata.kv_indptr.tolist()}" - f" kv_indices_len={self.forward_metadata.kv_indices.shape[0]}" - f" kv_indices_max={self.forward_metadata.kv_indices.max().item()}" - f" kv_last_page_len={self.forward_metadata.kv_last_page_len.tolist()}" - f" max_q_len={self.forward_metadata.max_q_len}" - f" sm_scale={layer.scaling}" - f" logit_cap={layer.logit_cap}" - f" k_scale={layer.k_scale}" - f" num_kv_splits={num_kv_splits}" - f" page_size={self.page_size}" - f" work_metadata={tuple(work_metadata.shape) if work_metadata is not None else None}" - f" work_indptr={tuple(work_indptr.shape) if work_indptr is not None else None}" - f" work_info_set={tuple(work_info_set.shape) if work_info_set is not None else None}" - f" reduce_indptr={tuple(reduce_indptr.shape) if reduce_indptr is not None else None} val={reduce_indptr.tolist() if reduce_indptr is not None and reduce_indptr.numel() < 20 else 'big'}" - f" reduce_final_map={tuple(reduce_final_map.shape) if reduce_final_map is not None else None}" - f" reduce_partial_map={tuple(reduce_partial_map.shape) if reduce_partial_map is not None else None}" - f" intra_batch_mode={_sglang_aiter.intra_batch_mode}" - f" _use_mla_ps_kernel={_sglang_aiter._use_mla_ps_kernel}" - f" fast_mode={_sglang_aiter.fast_mode}" - , flush=True, - ) + # print( + # f"[MLA_DECODE_DBG] layer=0" + # f" q={tuple(_q_view.shape)} q.dtype={_q_view.dtype}" + # f" k_buf={tuple(_k_view.shape)} k_buf.dtype={_k_view.dtype}" + # f" o={tuple(_o_view.shape)} o.dtype={_o_view.dtype}" + # f" qo_indptr={self.forward_metadata.qo_indptr.tolist()}" + # f" kv_indptr={self.forward_metadata.kv_indptr.tolist()}" + # f" kv_indices_len={self.forward_metadata.kv_indices.shape[0]}" + # f" kv_indices_max={self.forward_metadata.kv_indices.max().item()}" + # f" kv_last_page_len={self.forward_metadata.kv_last_page_len.tolist()}" + # f" max_q_len={self.forward_metadata.max_q_len}" + # f" sm_scale={layer.scaling}" + # f" logit_cap={layer.logit_cap}" + # 
f" k_scale={layer.k_scale}" + # f" num_kv_splits={num_kv_splits}" + # f" page_size={self.page_size}" + # f" work_metadata={tuple(work_metadata.shape) if work_metadata is not None else None}" + # f" work_indptr={tuple(work_indptr.shape) if work_indptr is not None else None}" + # f" work_info_set={tuple(work_info_set.shape) if work_info_set is not None else None}" + # f" reduce_indptr={tuple(reduce_indptr.shape) if reduce_indptr is not None else None} val={reduce_indptr.tolist() if reduce_indptr is not None and reduce_indptr.numel() < 20 else 'big'}" + # f" reduce_final_map={tuple(reduce_final_map.shape) if reduce_final_map is not None else None}" + # f" reduce_partial_map={tuple(reduce_partial_map.shape) if reduce_partial_map is not None else None}" + # f" intra_batch_mode={_sglang_aiter.intra_batch_mode}" + # f" _use_mla_ps_kernel={_sglang_aiter._use_mla_ps_kernel}" + # f" fast_mode={_sglang_aiter.fast_mode}" + # , flush=True, + # ) mla_decode_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), From 29a9759735201f475f6e3c1c7a348d83e9de5d03 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Fri, 20 Mar 2026 03:11:14 +0000 Subject: [PATCH 05/15] add Qwen3-235B support for sgl_oot Signed-off-by: zhuyuhua-v --- atom/plugin/sglang/oot/qwen3_moe.py | 97 +++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 atom/plugin/sglang/oot/qwen3_moe.py diff --git a/atom/plugin/sglang/oot/qwen3_moe.py b/atom/plugin/sglang/oot/qwen3_moe.py new file mode 100644 index 000000000..b998eaf50 --- /dev/null +++ b/atom/plugin/sglang/oot/qwen3_moe.py @@ -0,0 +1,97 @@ +"""ATOM Qwen model wrapper for SGLang external model loading. + +Registers Qwen3MoeForCausalLM and Qwen3ForCausalLM as external +model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's +built-in implementations with ATOM-optimized versions. +""" + +import logging +from typing import Iterable, Optional, Tuple, Union + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +logger = logging.getLogger("atom.plugin.sglang.oot") + + +class Qwen3MoeForCausalLM(nn.Module): + """ATOM-backed Qwen3 MoE model for SGLang. + + This wrapper delegates model creation and weight loading to ATOM's + plugin system, while conforming to sglang's model interface + (forward signature, LogitsProcessorOutput return type, load_weights). 
+ """ + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + import atom + + self.model = atom.prepare_model(config=config, engine="sglang") + if self.model is None: + model_arch = getattr(config, "architectures", ["unknown"])[0] + raise ValueError( + f"ATOM failed to create model for architecture {model_arch}" + ) + + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=input_embeds, + forward_batch=forward_batch, + get_embedding=get_embedding, + pp_proxy_tensors=pp_proxy_tensors, + ) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + from atom.model_loader.loader import load_model_in_plugin_mode + + return load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." + ) + + +class Qwen3ForCausalLM(Qwen3MoeForCausalLM): + pass + + +EntryClass = [Qwen3MoeForCausalLM, Qwen3ForCausalLM] From 2f9aa0b0c31efa3c77230ef462aa622a1116bbf8 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Fri, 20 Mar 2026 03:16:47 +0000 Subject: [PATCH 06/15] add Deepseek-R1 support for sgl_oot Signed-off-by: zhuyuhua-v --- atom/plugin/sglang/oot/deepseek_v2.py | 97 +++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 atom/plugin/sglang/oot/deepseek_v2.py diff --git a/atom/plugin/sglang/oot/deepseek_v2.py b/atom/plugin/sglang/oot/deepseek_v2.py new file mode 100644 index 000000000..e8f4444e2 --- /dev/null +++ b/atom/plugin/sglang/oot/deepseek_v2.py @@ -0,0 +1,97 @@ +"""ATOM DeepSeek model wrapper for SGLang external model loading. + +Registers DeepseekV3ForCausalLM and DeepseekV2ForCausalLM as external +model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's +built-in implementations with ATOM-optimized versions. +""" + +import logging +from typing import Iterable, Optional, Tuple, Union + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +logger = logging.getLogger("atom.plugin.sglang.oot") + + +class DeepseekV2ForCausalLM(nn.Module): + """ATOM-backed DeepSeek v2/v3 model for SGLang. + + This wrapper delegates model creation and weight loading to ATOM's + plugin system, while conforming to sglang's model interface + (forward signature, LogitsProcessorOutput return type, load_weights). 
+ """ + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + import atom + + self.model = atom.prepare_model(config=config, engine="sglang") + if self.model is None: + model_arch = getattr(config, "architectures", ["unknown"])[0] + raise ValueError( + f"ATOM failed to create model for architecture {model_arch}" + ) + + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=input_embeds, + forward_batch=forward_batch, + get_embedding=get_embedding, + pp_proxy_tensors=pp_proxy_tensors, + ) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + from atom.model_loader.loader import load_model_in_plugin_mode + + return load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." + ) + + +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass + + +EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM] From bcb63f244d225b1755201edee981b089b908b91d Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 23 Mar 2026 07:28:08 +0000 Subject: [PATCH 07/15] unify base model wrapper Signed-off-by: zhuyuhua-v --- .../{deepseek_v2.py => base_model_wrapper.py} | 36 ++++--- atom/plugin/sglang/oot/qwen3_moe.py | 97 ------------------- 2 files changed, 22 insertions(+), 111 deletions(-) rename atom/plugin/sglang/oot/{deepseek_v2.py => base_model_wrapper.py} (74%) delete mode 100644 atom/plugin/sglang/oot/qwen3_moe.py diff --git a/atom/plugin/sglang/oot/deepseek_v2.py b/atom/plugin/sglang/oot/base_model_wrapper.py similarity index 74% rename from atom/plugin/sglang/oot/deepseek_v2.py rename to atom/plugin/sglang/oot/base_model_wrapper.py index e8f4444e2..97eff70d4 100644 --- a/atom/plugin/sglang/oot/deepseek_v2.py +++ b/atom/plugin/sglang/oot/base_model_wrapper.py @@ -1,8 +1,9 @@ -"""ATOM DeepSeek model wrapper for SGLang external model loading. +"""ATOM model wrappers for SGLang external model loading (OOT). -Registers DeepseekV3ForCausalLM and DeepseekV2ForCausalLM as external -model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's -built-in implementations with ATOM-optimized versions. +Registers model architecture classes via SGLANG_EXTERNAL_MODEL_PACKAGE, +replacing sglang's built-in implementations with ATOM-optimized versions. + +To add a new model, append its architecture class name to _MODEL_NAMES. """ import logging @@ -18,13 +19,20 @@ logger = logging.getLogger("atom.plugin.sglang.oot") +_MODEL_NAMES = [ + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3ForCausalLM", +] + -class DeepseekV2ForCausalLM(nn.Module): - """ATOM-backed DeepSeek v2/v3 model for SGLang. 
+class _AtomCausalLMBaseForSglangOOT(nn.Module): + """Base ATOM model wrapper conforming to sglang's model interface. - This wrapper delegates model creation and weight loading to ATOM's - plugin system, while conforming to sglang's model interface - (forward signature, LogitsProcessorOutput return type, load_weights). + Delegates model creation and weight loading to ATOM's plugin system, + while providing the forward signature and LogitsProcessorOutput return + type that sglang expects. """ def __init__( @@ -90,8 +98,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) -class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): - pass - - -EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM] +EntryClass = [] +for _name in _MODEL_NAMES: + _cls = type(_name, (_AtomCausalLMBaseForSglangOOT,), {}) + globals()[_name] = _cls + EntryClass.append(_cls) diff --git a/atom/plugin/sglang/oot/qwen3_moe.py b/atom/plugin/sglang/oot/qwen3_moe.py deleted file mode 100644 index b998eaf50..000000000 --- a/atom/plugin/sglang/oot/qwen3_moe.py +++ /dev/null @@ -1,97 +0,0 @@ -"""ATOM Qwen model wrapper for SGLang external model loading. - -Registers Qwen3MoeForCausalLM and Qwen3ForCausalLM as external -model classes via SGLANG_EXTERNAL_MODEL_PACKAGE, replacing sglang's -built-in implementations with ATOM-optimized versions. -""" - -import logging -from typing import Iterable, Optional, Tuple, Union - -import torch -from torch import nn - -from sglang.srt.distributed import get_pp_group -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput -from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors - -logger = logging.getLogger("atom.plugin.sglang.oot") - - -class Qwen3MoeForCausalLM(nn.Module): - """ATOM-backed Qwen3 MoE model for SGLang. - - This wrapper delegates model creation and weight loading to ATOM's - plugin system, while conforming to sglang's model interface - (forward signature, LogitsProcessorOutput return type, load_weights). 
- """ - - def __init__( - self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - logger.info("Initializing ATOM backend for %s", self.__class__.__name__) - - self.pp_group = get_pp_group() - self.quant_config = quant_config - self.config = config - self.vocab_size = config.vocab_size - self.unpadded_vocab_size = config.vocab_size - - import atom - - self.model = atom.prepare_model(config=config, engine="sglang") - if self.model is None: - model_arch = getattr(config, "architectures", ["unknown"])[0] - raise ValueError( - f"ATOM failed to create model for architecture {model_arch}" - ) - - self.logits_processor = LogitsProcessor(config) - - @torch.no_grad() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, - get_embedding: bool = False, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=None, - inputs_embeds=input_embeds, - forward_batch=forward_batch, - get_embedding=get_embedding, - pp_proxy_tensors=pp_proxy_tensors, - ) - - if self.pp_group.is_last_rank: - return self.logits_processor( - input_ids, - hidden_states, - self.model.lm_head, - forward_batch, - ) - return hidden_states - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - from atom.model_loader.loader import load_model_in_plugin_mode - - return load_model_in_plugin_mode( - model=self.model, config=self.model.atom_config, prefix="model." - ) - - -class Qwen3ForCausalLM(Qwen3MoeForCausalLM): - pass - - -EntryClass = [Qwen3MoeForCausalLM, Qwen3ForCausalLM] From 24bc65f943028543f547fe2ec2faf02b8c0ff3c8 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 20 Mar 2026 05:40:58 +0000 Subject: [PATCH 08/15] enable fuse_rope_cat_and_cache_mla on gfx950 --- atom/model_ops/radix_attention.py | 3 ++- atom/models/deepseek_v2.py | 43 ++++++++++++++++++++++++------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 4340311b0..9ebffce1e 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -94,6 +94,7 @@ def forward_impl_plugin_mode( if is_sglang(): # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) + save_kv_cache = kwargs.get("save_kv_cache", not self.use_aiter_rope_fused_qknorm) assert forward_batch is not None, "forward_batch is required for sglang" # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM return self.attn( @@ -101,7 +102,7 @@ def forward_impl_plugin_mode( key, value, forward_batch=forward_batch, - save_kv_cache=not self.use_aiter_rope_fused_qknorm, + save_kv_cache=save_kv_cache, ) else: raise NotImplementedError( diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 98f2f649d..aef5fc625 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -50,6 +50,7 @@ fused_reduce_rms_mxfp4_quant, fused_rms_mxfp4_quant, ) +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from aiter.rotary_embedding import get_rope from atom.config import ( @@ -1480,6 +1481,7 @@ def __init__( self.use_nsa = is_deepseek_nsa(config) self.use_deep_gemm_bmm = False self.alt_stream 
= None + self.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 # self.w_kc, self.w_vc = self.kv_b_proj.weight.data.unflatten( # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) @@ -1720,12 +1722,7 @@ def _forward_sgl_prepare( q_nope_out = q_nope_out.transpose(0, 1) - if ( - self.rotary_emb is not None - # and (not self._fuse_rope_for_trtllm_mla(forward_batch)) - and (not _use_aiter or not _is_gfx95_supported or self.use_nsa) - ): - # Optional hard check during debugging + if self.rotary_emb is not None and not self.use_fused_qk_rope_concat_and_cache_mla: assert q_pe.shape[0] == positions.shape[0], ( f"q_pe tokens {q_pe.shape[0]} != positions {positions.shape[0]}" ) @@ -1765,9 +1762,35 @@ def _forward_sgl_core( llama_4_scaling, ): # 1) build q/k for radix attention path - _is_hip = True - q = torch.cat([q_nope_out, q_pe], dim=-1) - k = torch.cat([k_nope, k_pe], dim=-1) + save_kv_cache = True + + if self.use_fused_qk_rope_concat_and_cache_mla: + cos = self.rotary_emb.cos_cache + sin = self.rotary_emb.sin_cache + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( + self.layer_num + ) + k_scale = self.mla_attn.attn.k_scale + + q, _, k_pe_roped, _ = fused_qk_rope_cat_and_cache_mla( + q_nope_out, + q_pe, + k_nope, + k_pe, + kv_cache, + forward_batch.out_cache_loc, + positions, + cos, + sin, + k_scale, + self.rotary_emb.is_neox_style, + q_out_dtype=q_nope_out.dtype, + ) + k = torch.cat([k_nope, k_pe_roped], dim=-1) + save_kv_cache = False + else: + q = torch.cat([q_nope_out, q_pe], dim=-1) + k = torch.cat([k_nope, k_pe], dim=-1) if llama_4_scaling is not None: q = q * llama_4_scaling @@ -1778,7 +1801,7 @@ def _forward_sgl_core( k, k_nope, forward_batch=forward_batch, - save_kv_cache=True, + save_kv_cache=save_kv_cache, **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), ) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) From 288f81c8c1720e49c7af8cf00f3d848171385e73 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Fri, 20 Mar 2026 07:55:10 +0000 Subject: [PATCH 09/15] add _is_hip = True --- atom/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index aef5fc625..ddbee0e74 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1807,6 +1807,7 @@ def _forward_sgl_core( attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) # 3) up-proj by w_vc (port from sglang forward_absorb_core) + _is_hip = True if self.use_deep_gemm_bmm: attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( per_token_group_quant_mla_deep_gemm_masked_fp8(attn_output.transpose(0, 1)) From d626960441ec25a78fe05a74b85c1e48aef2ab9c Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 23 Mar 2026 14:17:35 +0000 Subject: [PATCH 10/15] fix AiterTensor check Signed-off-by: zhuyuhua-v --- atom/model_ops/radix_attention.py | 6 ++++-- atom/plugin/attention_backend/sgl_attn_backend.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 9ebffce1e..d0a9f0453 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -65,11 +65,13 @@ def __init__( ) if self.attn.k_scale is None: self.attn.k_scale = torch.nn.Parameter( - torch.tensor([1.0], dtype=torch.float32), requires_grad=False + torch.tensor([1.0], dtype=torch.float32, device="cuda"), + requires_grad=False, ) if 
self.attn.v_scale is None: self.attn.v_scale = torch.nn.Parameter( - torch.tensor([1.0], dtype=torch.float32), requires_grad=False + torch.tensor([1.0], dtype=torch.float32, device="cuda"), + requires_grad=False, ) else: raise NotImplementedError( diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index fe935fe8e..4a46f3dfb 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -803,7 +803,7 @@ def init_cuda_graph_state( max_num_tokens: int, kv_indices_buf: Optional[torch.Tensor] = None, ): - self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int, device=self.device) if kv_indices_buf is None: self.cuda_graph_kv_indices = torch.zeros( (max_bs * self.max_context_len), From d6e11b31d407b74fad6ca19c0a628c9d6c62ce8c Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Wed, 25 Mar 2026 07:13:27 +0000 Subject: [PATCH 11/15] clean mla code path, keep the origin design Signed-off-by: zhuyuhua-v --- atom/config.py | 2 +- atom/model_ops/linear.py | 10 - atom/model_ops/moe.py | 25 - atom/models/deepseek_v2.py | 410 ++------- atom/models/qwen3_moe.py | 2 +- .../attention_backend/sgl_attn_backend.py | 786 +++++------------- 6 files changed, 279 insertions(+), 956 deletions(-) diff --git a/atom/config.py b/atom/config.py index 089156244..94a268d8d 100644 --- a/atom/config.py +++ b/atom/config.py @@ -916,7 +916,7 @@ def __post_init__(self): # Compatible with both transformers < 5 rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} rope_params["rope_theta"] = self.hf_config.rope_theta - rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") + rope_params["rope_type"] = rope_params.get("rope_type", "default") self.hf_config.rope_parameters = rope_params self.quant_config = QuantizationConfig( self.hf_config, diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index ec716a755..a25f4070d 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -394,21 +394,11 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) - _diag_forward_counter = 0 @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 ) -> torch.Tensor: - if LinearBase._diag_forward_counter < 5: - LinearBase._diag_forward_counter += 1 - print( - f"[DIAG][LinearBase.forward] prefix={self.prefix} " - f"quant_type={self.quant_type} " - f"w_dtype={self.weight.dtype} " - f"w_scale_shape={tuple(self.weight_scale.shape) if self.weight_scale is not None else None} " - f"x shape={tuple(x.shape)}" - ) if self.quant_type.value == QuantType.No.value: y = tgemm.mm( x, diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index c83f55884..ca0416f06 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1610,20 +1610,7 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: def _process_block_quant(self, layer: nn.Module) -> None: assert self.quant_config["is_dynamic"] - print( - f"[DIAG][Fp8MoE._process_block_quant] BEFORE normalize: " - f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} " - f"w13_scale dtype={layer.w13_weight_scale.dtype} shape={tuple(layer.w13_weight_scale.shape)} " - f"need_normalize={self.need_normalize_e4m3fn_to_e4m3fnuz}" - ) self._normalize_weights_and_scales(layer) - 
print( - f"[DIAG][Fp8MoE._process_block_quant] AFTER normalize: " - f"w13 dtype={layer.w13_weight.dtype} " - f"w13_scale min={layer.w13_weight_scale.data.min().item():.6f} " - f"max={layer.w13_weight_scale.data.max().item():.6f} " - f"mean={layer.w13_weight_scale.data.float().mean().item():.6f}" - ) if not self.need_normalize_e4m3fn_to_e4m3fnuz: layer.w13_weight = nn.Parameter(layer.w13_weight.data, requires_grad=False) @@ -1636,7 +1623,6 @@ def _process_block_quant(self, layer: nn.Module) -> None: ) shuffle_weights(layer.w13_weight, layer.w2_weight) - print(f"[DIAG][Fp8MoE._process_block_quant] DONE shuffle") def _process_channel_quant(self, layer: nn.Module) -> None: """PTPTC""" @@ -1767,17 +1753,6 @@ def apply( # per_Tensor doesn't support num_local_tokens, so fallback to # rocm_aiter_fused_moe when using per-tensor or no modular kernel. if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: - if not hasattr(self, "_diag_apply_printed"): - self._diag_apply_printed = True - print( - f"[DIAG][Fp8MoE.apply] rocm_aiter path: " - f"quant_type={self.quant_type} " - f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} " - f"w13_scale shape={tuple(layer.w13_weight_scale.shape)} " - f"w13_scale min={layer.w13_weight_scale.data.float().min().item():.6f} " - f"max={layer.w13_weight_scale.data.float().max().item():.6f} " - f"x dtype={x.dtype} shape={tuple(x.shape)}" - ) return torch.ops.aiter.rocm_aiter_fused_moe( x, layer.w13_weight, diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 5bccdc73c..f5535bc31 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -92,7 +92,7 @@ from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.configs.model_config import is_deepseek_nsa -from sglang.srt.models.deepseek_common.utils import _use_aiter_gfx95,_use_aiter,_is_gfx95_supported +from sglang.srt.models.deepseek_common.utils import _use_aiter_gfx95, _use_aiter, _is_gfx95_supported, _is_hip from sglang.srt.layers.quantization.rocm_mxfp4_utils import batched_gemm_afp4wfp4_pre_quant from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, @@ -1498,21 +1498,75 @@ def __init__( self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True - # for sglang + # for sglang plugin mode self.use_nsa = is_deepseek_nsa(config) self.use_deep_gemm_bmm = False self.alt_stream = None self.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 - # self.w_kc, self.w_vc = self.kv_b_proj.weight.data.unflatten( - # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) - # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) self.w_kc, self.w_vc = None, None self.w_scale = None self.w_scale_k = None self.w_scale_v = None - # self.w_kc, self.w_vc = self.kv_b_proj.weight.data.unflatten( - # 0, (-1, self.qk_nope_head_dim + self.v_head_dim) - # ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) + + def _mla_absorbed_bmm(self, inp, weight, weight_scale, weight_scale_k, out_dim): + """Shared batched matmul for MLA absorbed weights (w_kc / w_vc). + + Handles deep_gemm, mxfp4, fp8-triton, fp8-cublas, and bf16 fallback paths. 
+ inp: (num_tokens, num_heads, in_dim) — token-major + Returns: (num_tokens, num_heads, out_dim) — token-major + """ + if self.use_deep_gemm_bmm: + val, scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(inp.transpose(0, 1)) + ) + out = inp.new_empty((self.num_local_heads, aligned_m, out_dim)) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (val, scale), (weight, weight_scale_k), out, masked_m, expected_m, + ) + return out[:, :expected_m, :].transpose(0, 1) + + if _is_hip: + if _use_aiter_gfx95 and weight.dtype == torch.uint8: + x = inp.transpose(0, 1) + out = torch.empty( + x.shape[0], x.shape[1], weight.shape[2], + device=x.device, dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, weight.transpose(-2, -1), + weight_scale_k.transpose(-2, -1), + torch.bfloat16, out, + ) + return out.transpose(0, 1) + + if (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) or ( + get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz + ): + out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=inp, WQ=weight.transpose(-1, -2), + w_scale=weight_scale, group_size=128, + YQ=None, transpose_bm=False, transpose_bm_in=True, + dtype=torch.bfloat16, + ) + return out.transpose(0, 1) + + out = torch.bmm( + inp.to(torch.bfloat16).transpose(0, 1), + weight.to(torch.bfloat16) * weight_scale, + ) + return out.transpose(0, 1) + + # CUDA fp8 path + if weight.dtype == torch.float8_e4m3fn: + val, scale = per_tensor_quant_mla_fp8( + inp.transpose(0, 1), + torch.zeros((1,), dtype=torch.float32, device=inp.device), + ) + out = bmm_fp8(val, weight, scale, weight_scale, torch.bfloat16) + return out.transpose(0, 1) + + # bf16 fallback + return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) def _forward_sgl_prepare( self, @@ -1520,11 +1574,10 @@ def _forward_sgl_prepare( hidden_states: torch.Tensor, **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: - # supplementary code, port from forward_common hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states - + forward_batch = model_kwargs.get("forward_batch", None) zero_allocator = model_kwargs.get("zero_allocator", None) llama_4_scaling = model_kwargs.get("llama_4_scaling", None) @@ -1565,56 +1618,9 @@ def _forward_sgl_prepare( k_nope = self.kv_a_layernorm(k_nope) current_stream.wait_stream(self.alt_stream) else: - # if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8: - # q, _, k_nope, *_ = fused_rms_mxfp4_quant( - # q, - # self.q_a_layernorm.weight, - # self.q_a_layernorm.variance_epsilon, - # k_nope, - # self.kv_a_layernorm.weight, - # self.kv_a_layernorm.variance_epsilon, - # ) - # else: - q_lora = None - _use_aiter_gfx95 = False - if ( - _use_aiter_gfx95 - and - self.q_b_proj.weight.dtype == torch.float8_e4m3fn - ): - if self.use_nsa: - q_quanted, q_lora, k_nope, _ = fused_rms_fp8_group_quant( - q, - self.q_a_layernorm.weight, - self.q_a_layernorm.variance_epsilon, - k_nope, - self.kv_a_layernorm.weight, - self.kv_a_layernorm.variance_epsilon, - group_size=128, - dtype_quant=torch.float8_e4m3fn, - res1=None, - output_unquantized_inp1=True, - ) - q = q_quanted - else: - q, _, k_nope, _ = fused_rms_fp8_group_quant( - q, - self.q_a_layernorm.weight, - self.q_a_layernorm.variance_epsilon, - k_nope, - self.kv_a_layernorm.weight, - self.kv_a_layernorm.variance_epsilon, - group_size=128, - dtype_quant=torch.float8_e4m3fn, - res1=None, - output_unquantized_inp1=False, - ) - - else: - q = 
self.q_a_layernorm(q) - k_nope = self.kv_a_layernorm(k_nope) + q = self.q_a_layernorm(q) + k_nope = self.kv_a_layernorm(k_nope) - # q_lora needed by indexer if self.use_nsa: if q_lora is None: q_lora = q @@ -1634,11 +1640,8 @@ def _forward_sgl_prepare( -1, self.num_local_heads, self.qk_head_dim ) topk_indices = self.indexer( - x=hidden_states, - q_lora=q_lora, - positions=positions, - forward_batch=forward_batch, - layer_id=self.layer_num, + x=hidden_states, q_lora=q_lora, positions=positions, + forward_batch=forward_batch, layer_id=self.layer_num, ) current_stream.wait_stream(self.alt_stream) else: @@ -1646,11 +1649,8 @@ def _forward_sgl_prepare( q = self.q_b_proj(q).view(-1, self.num_local_heads, self.qk_head_dim) if q_lora is not None: topk_indices = self.indexer( - x=hidden_states, - q_lora=q_lora, - positions=positions, - forward_batch=forward_batch, - layer_id=self.layer_num, + x=hidden_states, q_lora=q_lora, positions=positions, + forward_batch=forward_batch, layer_id=self.layer_num, ) else: q = self.q_proj(hidden_states).view( @@ -1662,149 +1662,41 @@ def _forward_sgl_prepare( q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1) - # print( - # f"[MLA_DBG][_forward_sgl_prepare][layer={self.layer_num}] " - # f"q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_pe={tuple(k_pe.shape)} " - # f"positions={tuple(positions.shape)}" - # ) - _is_hip= True - if self.use_deep_gemm_bmm: - q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = ( - per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1)) - ) - q_nope_out = q_nope.new_empty( - (self.num_local_heads, aligned_m, self.kv_lora_rank) - ) - deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( - (q_nope_val, q_nope_scale), - (self.w_kc, self.w_scale_k), - q_nope_out, - masked_m, - expected_m, - ) - q_nope_out = q_nope_out[:, :expected_m, :] - elif _is_hip: - # TODO(haishaw): add bmm_fp8 to ROCm - if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8: - x = q_nope.transpose(0, 1) - q_nope_out = torch.empty( - x.shape[0], - x.shape[1], - self.w_kc.shape[2], - device=x.device, - dtype=torch.bfloat16, - ) - batched_gemm_afp4wfp4_pre_quant( - x, - self.w_kc.transpose(-2, -1), - self.w_scale_k.transpose(-2, -1), - torch.bfloat16, - q_nope_out, - ) - else: - if (_use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn) or ( - get_is_capture_mode() and self.w_kc.dtype == torch.float8_e4m3fnuz - ): - # fp8 Triton kernel: always on gfx950, - # cudagraph-only on gfx942 (hides launch overhead) - q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( - X=q_nope, - WQ=self.w_kc.transpose(-1, -2), - w_scale=self.w_scale, - group_size=128, - YQ=None, # allocate (B, M, N) - transpose_bm=False, # (B, M, N) - transpose_bm_in=True, # (M, B, K) - dtype=torch.bfloat16, - ) - - else: - q_nope_out = torch.bmm( - q_nope.to(torch.bfloat16).transpose(0, 1), - self.w_kc.to(torch.bfloat16) * self.w_scale, - ) - - elif self.w_kc.dtype == torch.float8_e4m3fn: - # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612 - q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( - q_nope.transpose(0, 1), - ( - torch.zeros((1,), dtype=torch.float32, device=q_nope.device) - # if _is_cublas_ge_129 - # else zero_allocator.allocate(1) - ), - ) - q_nope_out = bmm_fp8( - q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 - ) - else: - q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) - - 
q_nope_out = q_nope_out.transpose(0, 1) + q_nope_out = self._mla_absorbed_bmm( + q_nope, self.w_kc, self.w_scale, self.w_scale_k, self.kv_lora_rank, + ) if self.rotary_emb is not None and not self.use_fused_qk_rope_concat_and_cache_mla: - assert q_pe.shape[0] == positions.shape[0], ( - f"q_pe tokens {q_pe.shape[0]} != positions {positions.shape[0]}" - ) - assert k_pe.shape[0] == positions.shape[0], ( - f"k_pe tokens {k_pe.shape[0]} != positions {positions.shape[0]}" - ) q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) if nsa_use_prefill_cp(forward_batch): - # support allgather+rerrange k_nope, k_pe = self.rebuild_cp_kv_cache( latent_cache, forward_batch, k_nope, k_pe ) - # end forward prepare + return ( - q_pe, - k_pe, - q_nope_out, - k_nope, - forward_batch, - zero_allocator, - positions, - topk_indices, - llama_4_scaling, + q_pe, k_pe, q_nope_out, k_nope, + forward_batch, zero_allocator, positions, topk_indices, llama_4_scaling, ) def _forward_sgl_core( self, - q_pe, - k_pe, - q_nope_out, - k_nope, - forward_batch, - zero_allocator, - positions, - topk_indices, - llama_4_scaling, + q_pe, k_pe, q_nope_out, k_nope, + forward_batch, zero_allocator, positions, topk_indices, llama_4_scaling, ): - # 1) build q/k for radix attention path save_kv_cache = True if self.use_fused_qk_rope_concat_and_cache_mla: cos = self.rotary_emb.cos_cache sin = self.rotary_emb.sin_cache - kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( - self.layer_num - ) + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(self.layer_num) k_scale = self.mla_attn.attn.k_scale q, _, k_pe_roped, _ = fused_qk_rope_cat_and_cache_mla( - q_nope_out, - q_pe, - k_nope, - k_pe, - kv_cache, - forward_batch.out_cache_loc, - positions, - cos, - sin, - k_scale, - self.rotary_emb.is_neox_style, + q_nope_out, q_pe, k_nope, k_pe, + kv_cache, forward_batch.out_cache_loc, positions, + cos, sin, k_scale, self.rotary_emb.is_neox_style, q_out_dtype=q_nope_out.dtype, ) k = torch.cat([k_nope, k_pe_roped], dim=-1) @@ -1816,118 +1708,30 @@ def _forward_sgl_core( if llama_4_scaling is not None: q = q * llama_4_scaling - # 2) attention core attn_output = self.mla_attn( - q, - k, - k_nope, + q, k, k_nope, forward_batch=forward_batch, save_kv_cache=save_kv_cache, **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), ) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - # 3) up-proj by w_vc (port from sglang forward_absorb_core) - _is_hip = True - if self.use_deep_gemm_bmm: - attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = ( - per_token_group_quant_mla_deep_gemm_masked_fp8(attn_output.transpose(0, 1)) - ) - attn_bmm_output = attn_output.new_empty( - (self.num_local_heads, aligned_m, self.v_head_dim) - ) - deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( - (attn_output_val, attn_output_scale), - (self.w_vc, self.w_scale_v), - attn_bmm_output, - masked_m, - expected_m, - ) - attn_bmm_output = ( - attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2) - ) + # up-proj by w_vc + attn_bmm_output = self._mla_absorbed_bmm( + attn_output, self.w_vc, self.w_scale, self.w_scale_v, self.v_head_dim, + ).flatten(1, 2) - elif _is_hip: - if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8: - x = attn_output.transpose(0, 1) - y = torch.empty( - x.shape[0], - x.shape[1], - self.w_vc.shape[2], - device=x.device, - dtype=torch.bfloat16, - ) - batched_gemm_afp4wfp4_pre_quant( - x, - self.w_vc.transpose(-2, -1), - self.w_scale_v.transpose(-2, -1), - torch.bfloat16, - y, - ) - 
attn_bmm_output = y.transpose(0, 1).flatten(1, 2) - else: - if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn: - y = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( - X=attn_output, - WQ=self.w_vc.transpose(-1, -2), - w_scale=self.w_scale, - group_size=128, - YQ=None, - transpose_bm=False, - transpose_bm_in=True, - dtype=torch.bfloat16, - ) - else: - y = torch.bmm( - attn_output.to(torch.bfloat16).transpose(0, 1), - self.w_vc.to(torch.bfloat16) * self.w_scale, - ) - attn_bmm_output = y.transpose(0, 1).flatten(1, 2) - - elif self.w_vc.dtype == torch.float8_e4m3fn: - attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( - attn_output.transpose(0, 1), - torch.zeros((1,), dtype=torch.float32, device=attn_output.device), - ) - attn_bmm_output = bmm_fp8( - attn_output_val, - self.w_vc, - attn_output_scale, - self.w_scale, - torch.bfloat16, - ) - attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) + return self.o_proj(attn_bmm_output) - else: - attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) - attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) - - output = self.o_proj(attn_bmm_output) - return output - - def prepare_qkv_latent( - self, - hidden_states: torch.Tensor, - forward_batch, - ): + def prepare_qkv_latent(self, hidden_states: torch.Tensor, forward_batch): assert self.q_lora_rank is not None hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states - # print( - # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] " - # f"hidden_states={tuple(hidden_states.shape)} " - # f"hs_scale={None if hidden_states_scale is None else tuple(hidden_states_scale.shape)} " - # f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " - # f"positions={None if getattr(forward_batch, 'positions', None) is None else tuple(forward_batch.positions.shape)}" - # ) qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) - # print(f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] qkv_lora={tuple(qkv_lora.shape)}") # Fallback: when communicator does not enable input_scattered gather, # force qkv latent token dimension to align with positions. - # Use positions.shape[0] (actual input token count) instead of - # seq_lens_sum (total KV cache length, wrong for decode mode). 
expected_tokens = 0 if hasattr(forward_batch, "positions") and forward_batch.positions is not None: expected_tokens = int(forward_batch.positions.shape[0]) @@ -1939,11 +1743,6 @@ def prepare_qkv_latent( and qkv_lora.shape[0] != expected_tokens and get_tensor_model_parallel_world_size() > 1 ): - # print( - # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] before_fallback_gather " - # f"qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens} " - # f"tp_world={get_tensor_model_parallel_world_size()}" - # ) qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) if qkv_lora.shape[0] > expected_tokens: qkv_lora = qkv_lora[:expected_tokens] @@ -1952,51 +1751,31 @@ def prepare_qkv_latent( f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " f"expected {expected_tokens}" ) - # print( - # f"[MLA_DBG][prepare_qkv_latent][layer={self.layer_num}] return_qkv_lora={tuple(qkv_lora.shape)} expected={expected_tokens}" - # ) return qkv_lora - def forward_sgl_plugin_mode( self, positions: torch.Tensor, hidden_states: torch.Tensor, **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: - # forward_absorb_prepare sglang from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode forward_batch = model_kwargs.get("forward_batch", None) if forward_batch is None: raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") attn_tp_context = get_attn_tp_context() - # print( - # f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " - # f"positions={tuple(positions.shape)} " - # f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " - # f"seq_lens_sum={getattr(forward_batch, 'seq_lens_sum', None)} " - # f"input_ids={None if getattr(forward_batch, 'input_ids', None) is None else tuple(forward_batch.input_ids.shape)} " - # f"allow_scatter={attn_tp_context.allow_input_scattered}" - # ) with attn_tp_context.maybe_input_scattered(forward_batch): - # print( - # f"[MLA_DBG][forward_sgl_plugin_mode][layer={self.layer_num}] " - # f"input_scattered={attn_tp_context.input_scattered}" - # ) if self.q_lora_rank is not None: attn_tp_context.set_attn_inputs( AttentionInputs( - hidden_states, - forward_batch, - self.prepare_qkv_latent, + hidden_states, forward_batch, self.prepare_qkv_latent, ) ) prepared = self._forward_sgl_prepare(positions, hidden_states, **model_kwargs) return self._forward_sgl_core(*prepared) def forward_common( - self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -2088,7 +1867,7 @@ def forward_common( k_pe, positions, hidden_states_or_q_c_scale, - ) + ) def forward( self, @@ -2393,13 +2172,6 @@ def forward( residual: Optional[torch.Tensor], **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: - # Self Attention - # print( - # f"[MLA_DBG][decoder_layer][layer={self.layer_idx}] positions={tuple(positions.shape)} " - # f"hidden_states={'tuple' if isinstance(hidden_states, tuple) else tuple(hidden_states.shape)} " - # f"residual={None if residual is None else tuple(residual.shape)} " - # f"fuse_input_norm_quant={self.fuse_input_norm_quant}" - # ) if self.fuse_input_norm_quant: assert self.quant_dtype is not None weight = self.input_layernorm.weight @@ -2692,7 +2464,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # wrapper class, so the name of loaded weights are prefixed with "model.". 
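# A small worked example of the fallback above, with hypothetical numbers.
# In decode mode each request contributes exactly one new token, so the number
# of input tokens equals positions.shape[0] (4 here), while seq_lens_sum counts
# the whole KV history (425 here) and would be the wrong target size.
import torch

seq_lens = torch.tensor([100, 200, 50, 75])
positions = seq_lens - 1                  # one position per decode request
expected_tokens = positions.shape[0]      # -> 4, not int(seq_lens.sum()) == 425

# If the communicator left the latent scattered across 2 TP ranks (2 tokens
# each), an all-gather along dim 0 followed by truncation restores the
# per-position layout; torch.cat stands in for get_tp_group().all_gather here.
rank0_part, rank1_part = torch.randn(2, 576), torch.randn(2, 576)
qkv_lora = torch.cat([rank0_part, rank1_part], dim=0)
qkv_lora = qkv_lora[:expected_tokens]
assert qkv_lora.shape[0] == expected_tokens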
# The vLLM will check the name of the loaded weights to make sure all the # weights are loaded correctly - + # lazy import to avoid circular import issue since model_loader also imports model.. from atom.model_loader.loader import load_model_in_plugin_mode loaded_weights_record = load_model_in_plugin_mode( diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 279845595..cb6d74c75 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -599,4 +599,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_weights_record = load_model_in_plugin_mode( model=self, config=self.atom_config, prefix="model." ) - return loaded_weights_record \ No newline at end of file + return loaded_weights_record diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 4a46f3dfb..a7b42708a 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -187,11 +187,7 @@ class ForwardMetadata: pa_metadata_context_lens: Optional[torch.Tensor] = None pa_metadata_max_qlen: Optional[int] = None pa_metadata_tp_q_head_num: Optional[int] = None - # Prefill metadata for mha_batch_prefill_func (only used in prefill mode, non-MLA) - # prefill_pages_kv_indptr: Optional[torch.Tensor] = None - # prefill_kv_indices: Optional[torch.Tensor] = None - # prefill_kv_last_page_lens: Optional[torch.Tensor] = None - + class ATOMAttnBackendForSgl(AiterAttnBackend): @@ -415,12 +411,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # Build pa_metadata for pa_persistent_fwd if self.decode_using_pa_ps: self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) - # return # Early return for non-MLA decode mode else: prefix_lens = forward_batch.extend_prefix_lens if self.use_mla: - # raise NotImplementedError("MLA prefill mode is not implemented yet in ATOMAttnBackendForSgl.") self.mla_indices_updater_prefill.update( forward_batch.req_pool_indices, forward_batch.seq_lens, @@ -441,7 +435,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): reduce_final_map = None fp8_prefill_kv_indices = None reduce_partial_map = None - + from sglang.srt.utils import is_gfx95_supported _use_fp8_prefill_attn = ( get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported() @@ -478,7 +472,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): fp8_prefill_kv_indices = torch.arange( total_s, device=self.device, dtype=torch.int32 ) - + self.forward_metadata = ForwardMetadata( self.mla_indices_updater_prefill.kv_indptr, self.mla_indices_updater_prefill.kv_indices, @@ -823,7 +817,7 @@ def init_cuda_graph_state( self.strided_indices = torch.arange( 0, self.max_context_len, self.page_size, device=self.device ) - + if self.use_mla and _sglang_aiter._use_mla_ps_kernel: max_seqlen_qo = 1 ( @@ -854,7 +848,7 @@ def init_cuda_graph_state( max_bs, self.num_kv_head, ) - + self._allocate_pa_metadata_buffers( work_metadata_ptrs_size, work_metadata_ptrs_type, @@ -870,6 +864,81 @@ def init_cuda_graph_state( reduce_partial_map_type, ) + def _init_mla_cuda_graph_metadata(self, bs, req_pool_indices, seq_lens): + """Shared MLA decode metadata setup for CUDA graph capture/replay.""" + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + 
seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + def init_forward_metadata_capture_cuda_graph( self, bs: int, @@ -880,100 +949,20 @@ def init_forward_metadata_capture_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInput], ): - if forward_mode.is_decode_or_idle(): - if self.use_mla: - kv_indptr = self.kv_indptr - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = self.cuda_graph_kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - - qo_indptr = self.qo_indptr_[: bs + 1] - qo_indptr[1 : bs + 1] = torch.cumsum( - self.cuda_graph_kv_last_page_len[:bs], dim=0 - ) - kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] - max_q_len = 1 - - work_metadata = None - work_indptr = None - work_info_set = None - reduce_indptr = None - reduce_final_map = None - reduce_partial_map = None - num_kv_splits = None - - if _sglang_aiter._use_mla_ps_kernel: - num_kv_splits = self.max_split_per_batch - - self.make_mla_meta_data( - qo_indptr, - kv_indptr, - kv_last_page_len, - self.work_metadata, - self.work_info_set, - self.work_indptr, - self.reduce_indptr, - self.reduce_final_map, - self.reduce_partial_map, - max_q_len, - fast_mode=_sglang_aiter.fast_mode, - max_split_per_batch=num_kv_splits, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - ) - - work_metadata = self.work_metadata - work_info_set = self.work_info_set - work_indptr = self.work_indptr - reduce_indptr = self.reduce_indptr - reduce_final_map = self.reduce_final_map - reduce_partial_map = self.reduce_partial_map + if not forward_mode.is_decode_or_idle(): + raise ValueError(f"Invalid mode: {forward_mode=}") - self.forward_metadata = ForwardMetadata( - kv_indptr, - kv_indices, - qo_indptr, - kv_last_page_len, - max_q_len, - None, - None, - None, - work_metadata=work_metadata, - work_info_set=work_info_set, - work_indptr=work_indptr, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - 
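# A worked example of the indptr construction used above: kv_indptr is an
# exclusive prefix sum over the per-request KV lengths, and in decode every
# request contributes exactly one query token, so qo_indptr degenerates to
# 0..bs. Lengths are hypothetical.
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
bs = seq_lens.numel()

kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)          # -> [0, 3, 8, 10]

kv_last_page_len = torch.ones(bs, dtype=torch.int32)   # 1 query token per request
qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(kv_last_page_len, dim=0)  # -> [0, 1, 2, 3]

# request i's KV slot ids then live in kv_indices[kv_indptr[i]:kv_indptr[i+1]]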
reduce_partial_map=reduce_partial_map, - num_kv_splits=num_kv_splits, - ) - else: - page_table = self.page_table[:bs, :] - self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) - seq_lens_persistent = self.seq_lens[:bs] - self.forward_metadata = ForwardMetadata( - None, - None, - None, - None, - 1, - None, - page_table, - seq_lens_persistent, - ) - - if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) - return + if self.use_mla: + self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) else: - raise ValueError(f"Invalid mode: {forward_mode=}") + page_table = self.page_table[:bs, :] + self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens[:bs] + self.forward_metadata = ForwardMetadata( + None, None, None, None, 1, None, page_table, seq_lens_persistent, + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) def init_forward_metadata_replay_cuda_graph( self, @@ -987,105 +976,28 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], out_cache_loc: Optional[torch.Tensor] = None, ): - if forward_mode.is_decode_or_idle(): - if self.use_mla: - kv_indptr = self.kv_indptr - kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = self.cuda_graph_kv_indices - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - - qo_indptr = self.qo_indptr_[: bs + 1] - qo_indptr[1 : bs + 1] = torch.cumsum( - self.cuda_graph_kv_last_page_len[:bs], dim=0 - ) - kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] - max_q_len = 1 - - work_metadata = None - work_indptr = None - work_info_set = None - reduce_indptr = None - reduce_final_map = None - reduce_partial_map = None - num_kv_splits = None - - if _sglang_aiter._use_mla_ps_kernel: - num_kv_splits = self.max_split_per_batch - - self.make_mla_meta_data( - qo_indptr, - kv_indptr, - kv_last_page_len, - self.work_metadata, - self.work_info_set, - self.work_indptr, - self.reduce_indptr, - self.reduce_final_map, - self.reduce_partial_map, - max_q_len, - fast_mode=_sglang_aiter.fast_mode, - max_split_per_batch=num_kv_splits, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - ) - - work_metadata = self.work_metadata - work_info_set = self.work_info_set - work_indptr = self.work_indptr - reduce_indptr = self.reduce_indptr - reduce_final_map = self.reduce_final_map - reduce_partial_map = self.reduce_partial_map - - self.forward_metadata = ForwardMetadata( - kv_indptr, - kv_indices, - qo_indptr, - kv_last_page_len, - max_q_len, - None, - None, - None, - work_metadata=work_metadata, - work_info_set=work_info_set, - work_indptr=work_indptr, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - num_kv_splits=num_kv_splits, - ) - else: - page_table_persistent = self.page_table - seq_lens_persistent = self.seq_lens - seq_lens_persistent.fill_(0) - page_table_persistent.fill_(0) - seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) - max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 - page_table = self.req_to_token[req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] - page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + if not forward_mode.is_decode_or_idle(): + raise 
ValueError("Invalid forward mode") - self.forward_metadata = ForwardMetadata( - None, - None, - None, - None, - 1, - None, - page_table_persistent[:bs, :max_seq_pages], - seq_lens_persistent[:bs], - ) - - if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + if self.use_mla: + self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) else: - raise ValueError("Invalid forward mode") + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + page_table = self.req_to_token[req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] + page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + + self.forward_metadata = ForwardMetadata( + None, None, None, None, 1, None, + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) def set_kv_buffer_with_layout_shuffle( self, @@ -1202,12 +1114,11 @@ def _forward_extend_mla(self, q, k, v, layer, forward_batch): kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, ) - elif forward_batch.forward_mode.is_target_verify(): - return self._forward_extend_mla_target_verify( - q, layer, K_Buffer, qo_indptr, - ) - elif forward_batch.forward_mode.is_draft_extend(): - return self._forward_extend_mla_draft_extend( + elif ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ): + return self._forward_extend_mla_speculative( q, layer, K_Buffer, qo_indptr, ) else: @@ -1427,86 +1338,35 @@ def _forward_extend_mla_normal( ) return o - def _forward_extend_mla_target_verify(self, q, layer, K_Buffer, qo_indptr): - """MLA target_verify path (speculative decoding verification).""" - o = q.new_empty( - (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), - dtype=self.input_dtype, - ) - - work_metadata = self.forward_metadata.work_metadata - work_indptr = self.forward_metadata.work_indptr - work_info_set = self.forward_metadata.work_info_set - reduce_indptr = self.forward_metadata.reduce_indptr - reduce_final_map = self.forward_metadata.reduce_final_map - reduce_partial_map = self.forward_metadata.reduce_partial_map - num_kv_splits = self.forward_metadata.num_kv_splits - + def _call_mla_decode_fwd(self, q, k_buffer, o, layer): + """Common mla_decode_fwd invocation shared across decode/extend paths.""" + md = self.forward_metadata mla_decode_fwd( - q, - K_Buffer.view(-1, 1, 1, layer.qk_head_dim), - o, - self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.forward_metadata.kv_last_page_len, - self.forward_metadata.max_q_len, - sm_scale=layer.scaling, - logit_cap=layer.logit_cap, - work_meta_data=work_metadata, - work_indptr=work_indptr, - work_info_set=work_info_set, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - q_scale=layer.k_scale, - kv_scale=layer.k_scale, + q, k_buffer.view(-1, 1, 1, layer.qk_head_dim), o, + md.qo_indptr, md.kv_indptr, md.kv_indices, + md.kv_last_page_len, md.max_q_len, + sm_scale=layer.scaling, logit_cap=layer.logit_cap, + work_meta_data=md.work_metadata, + 
work_indptr=md.work_indptr, + work_info_set=md.work_info_set, + reduce_indptr=md.reduce_indptr, + reduce_final_map=md.reduce_final_map, + reduce_partial_map=md.reduce_partial_map, + q_scale=layer.k_scale, kv_scale=layer.k_scale, intra_batch_mode=_sglang_aiter.intra_batch_mode, - num_kv_splits=num_kv_splits, + num_kv_splits=md.num_kv_splits, ) - return o - def _forward_extend_mla_draft_extend(self, q, layer, K_Buffer, qo_indptr): - """MLA draft_extend path (speculative decoding draft extension).""" + def _forward_extend_mla_speculative(self, q, layer, K_Buffer, qo_indptr): + """MLA speculative path (target_verify / draft_extend).""" o = q.new_empty( (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), dtype=self.input_dtype, ) - - work_metadata = self.forward_metadata.work_metadata - work_indptr = self.forward_metadata.work_indptr - work_info_set = self.forward_metadata.work_info_set - reduce_indptr = self.forward_metadata.reduce_indptr - reduce_final_map = self.forward_metadata.reduce_final_map - reduce_partial_map = self.forward_metadata.reduce_partial_map - num_kv_splits = self.forward_metadata.num_kv_splits - - mla_decode_fwd( - q, - K_Buffer.view(-1, 1, 1, layer.qk_head_dim), - o, - self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.forward_metadata.kv_last_page_len, - self.forward_metadata.max_q_len, - sm_scale=layer.scaling, - logit_cap=layer.logit_cap, - work_meta_data=work_metadata, - work_indptr=work_indptr, - work_info_set=work_info_set, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - q_scale=layer.k_scale, - kv_scale=layer.k_scale, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - num_kv_splits=num_kv_splits, - ) + self._call_mla_decode_fwd(q, K_Buffer, o, layer) return o - - def forward_decode_pa( + def forward_decode( self, q: torch.Tensor, k: torch.Tensor, @@ -1516,351 +1376,77 @@ def forward_decode_pa( save_kv_cache=True, ): q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) - - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) - else: - o = torch.empty_like(q) - - if save_kv_cache: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( - layer.layer_id - ) - self.set_kv_buffer_with_layout_shuffle( - forward_batch.out_cache_loc, - k, - v, - k_buffer, - v_buffer, - layer.k_scale, - layer.v_scale, - self.page_size, - ) + batch_size = q.shape[0] + head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim if self.use_mla: + o = q.new_empty( + (batch_size, layer.tp_q_head_num * head_dim_out), dtype=self.input_dtype, + ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - - work_metadata = self.forward_metadata.work_metadata - work_indptr = self.forward_metadata.work_indptr - work_info_set = self.forward_metadata.work_info_set - - reduce_indptr = self.forward_metadata.reduce_indptr - reduce_final_map = self.forward_metadata.reduce_final_map - reduce_partial_map = self.forward_metadata.reduce_partial_map - - num_kv_splits = self.forward_metadata.num_kv_splits - - mla_decode_fwd( + self._call_mla_decode_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k_buffer.view(-1, 1, 1, layer.qk_head_dim), + k_buffer, o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - self.forward_metadata.qo_indptr, - 
self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.forward_metadata.kv_last_page_len, - self.forward_metadata.max_q_len, - sm_scale=layer.scaling, - logit_cap=layer.logit_cap, - work_meta_data=work_metadata, - work_indptr=work_indptr, - work_info_set=work_info_set, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - q_scale=layer.k_scale, - kv_scale=layer.k_scale, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - num_kv_splits=num_kv_splits, - ) - - else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( - layer.layer_id - ) - - block_size = self.page_size - num_slots, num_kv_heads, head_size = k_buffer.shape - num_blocks = num_slots // block_size - k_buffer = k_buffer[: num_blocks * block_size].view( - num_blocks, block_size, num_kv_heads, head_size - ) - v_buffer = v_buffer[: num_blocks * block_size].view( - num_blocks, block_size, num_kv_heads, head_size - ) - - x = 16 // k_buffer.element_size() - k_cache_template = torch.empty( - [num_blocks, num_kv_heads, head_size // x, block_size, x], - dtype=k_buffer.dtype, - device="meta", - ) - v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], - dtype=v_buffer.dtype, - device="meta", - ) - new_key_cache = k_buffer.view_as(k_cache_template) - new_value_cache = v_buffer.view_as(v_cache_template) - q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - pa_fwd_asm( - Q=q, - K=new_key_cache, - V=new_value_cache, - block_tables=self.forward_metadata.page_table, - context_lens=self.forward_metadata.kv_lens, - block_tables_stride0=self.forward_metadata.page_table.stride(0), - K_QScale=self.k_scale, - V_QScale=self.v_scale, - out_=o, + layer, ) return o - def forward_decode_pa_ps( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): - q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) - - # Create o as 3D tensor [batch_size, num_heads, head_dim] for both MLA and pa_fwd_asm - # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token) - # Use q.shape[0] instead of forward_batch.batch_size to be safe - batch_size = q.shape[0] - head_dim_out = ( - layer.v_head_dim - if layer.qk_head_dim != layer.v_head_dim - else layer.head_dim - ) + # Non-MLA decode paths o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) if save_kv_cache: - if self.use_mla: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) - else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( - layer.layer_id - ) - self.set_kv_buffer_with_layout_shuffle(forward_batch.out_cache_loc, k, v, k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size) - - if self.use_mla: - k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - - work_metadata = self.forward_metadata.work_metadata - work_indptr = self.forward_metadata.work_indptr - work_info_set = self.forward_metadata.work_info_set - - reduce_indptr = self.forward_metadata.reduce_indptr - reduce_final_map = self.forward_metadata.reduce_final_map - reduce_partial_map = self.forward_metadata.reduce_partial_map - - num_kv_splits = self.forward_metadata.num_kv_splits - - mla_decode_fwd( - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k_buffer.view(-1, 1, 1, layer.qk_head_dim), - o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - self.forward_metadata.qo_indptr, - 
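# A small sketch of the cache-view arithmetic above: x is the number of elements
# that fit in 16 bytes, so bf16 gives x == 8 and fp8 gives x == 16. The 5-D view
# reinterprets the same contiguous buffer in the packed layout the aiter paged-
# attention kernels expect (the matching data shuffle happens at write time in
# set_kv_buffer_with_layout_shuffle). Sizes below are hypothetical.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
k_buffer = torch.zeros(num_blocks, block_size, num_kv_heads, head_size,
                       dtype=torch.bfloat16)

x = 16 // k_buffer.element_size()                    # 16 // 2 == 8 for bf16
new_key_cache = k_buffer.view(num_blocks, num_kv_heads, head_size // x, block_size, x)
assert new_key_cache.numel() == k_buffer.numel()     # pure reinterpretation, no copy
# the value buffer gets the analogous (num_blocks, heads, block_size // x, head_size, x) view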
self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.forward_metadata.kv_last_page_len, - self.forward_metadata.max_q_len, - sm_scale=layer.scaling, - logit_cap=layer.logit_cap, - work_meta_data=work_metadata, - work_indptr=work_indptr, - work_info_set=work_info_set, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - q_scale=layer.k_scale, - kv_scale=layer.k_scale, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - num_kv_splits=num_kv_splits, - ) - else: - k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( - layer.layer_id - ) - num_slots, num_kv_heads, head_size = k_buffer.shape - block_size = self.page_size - num_blocks = num_slots // block_size - k_buffer = k_buffer[: num_blocks * block_size].view( - num_blocks, block_size, num_kv_heads, head_size - ) - v_buffer = v_buffer[: num_blocks * block_size].view( - num_blocks, block_size, num_kv_heads, head_size + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, k, v, + k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size, ) - quant_dtype = dtypes.fp8 - x = 16 // quant_dtype.itemsize - k_cache_template = torch.empty( - [num_blocks, num_kv_heads, head_size // x, block_size, x], - dtype=k_buffer.dtype, - device="meta", - ) - # V: [num_blocks, block_size, num_kv_heads, head_size] -> [num_blocks, num_kv_heads, block_size // x, head_size, x] - v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], - dtype=v_buffer.dtype, - device="meta", - ) - new_key_cache = k_buffer.view_as(k_cache_template) - new_value_cache = v_buffer.view_as(v_cache_template) + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + block_size = self.page_size + num_slots, num_kv_heads, head_size = k_buffer.shape + num_blocks = num_slots // block_size + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + x = 16 // k_buffer.element_size() + new_key_cache = k_buffer.view(num_blocks, num_kv_heads, head_size // x, block_size, x) + new_value_cache = v_buffer.view(num_blocks, num_kv_heads, block_size // x, head_size, x) + if self.decode_using_pa_ps: total_tokens = num_blocks * block_size - k_qscale = self.k_qscale[:, :total_tokens] - v_qscale = self.v_qscale[:, :total_tokens] - - q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) - - assert ( - self.forward_metadata.pa_metadata_qo_indptr is not None - ), "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_pages_kv_indptr is not None - ), "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_kv_indices is not None - ), "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_context_lens is not None - ), "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode" - assert ( - self.forward_metadata.pa_metadata_max_qlen is not None - ), "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode" - - qo_indptr = self.forward_metadata.pa_metadata_qo_indptr - kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr - kv_indices = 
self.forward_metadata.pa_metadata_kv_indices - context_lens = self.forward_metadata.pa_metadata_context_lens - max_qlen = self.forward_metadata.pa_metadata_max_qlen - - _, _ = pa_persistent_fwd( - Q=q, - K=new_key_cache, - V=new_value_cache, - output=o, - max_qlen=max_qlen, - qo_indptr=qo_indptr, - kv_indptr=kv_indptr, - kv_indices=kv_indices, - context_lens=context_lens, + q_3d = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) + pa_persistent_fwd( + Q=q_3d, K=new_key_cache, V=new_value_cache, output=o, + max_qlen=self.forward_metadata.pa_metadata_max_qlen, + qo_indptr=self.forward_metadata.pa_metadata_qo_indptr, + kv_indptr=self.forward_metadata.pa_metadata_pages_kv_indptr, + kv_indices=self.forward_metadata.pa_metadata_kv_indices, + context_lens=self.forward_metadata.pa_metadata_context_lens, work_indptr=self.pa_metadata_buffers["work_indptr"], work_info=self.pa_metadata_buffers["work_info"], reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], - K_QScale=k_qscale, - V_QScale=v_qscale, - softmax_scale=layer.scaling, - mask=1, - ) - return o.view(-1, layer.tp_q_head_num * head_dim_out) - - def forward_decode( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - save_kv_cache=True, - ): - if self.use_mla: - return self._forward_decode_mla(q, k, v, layer, forward_batch, save_kv_cache) - else: - if self.decode_using_pa_ps: - return self.forward_decode_pa_ps( - q, k, v, layer, forward_batch, save_kv_cache - ) - else: - return self.forward_decode_pa(q, k, v, layer, forward_batch, save_kv_cache) - - def _forward_decode_mla(self, q, k, v, layer, forward_batch, save_kv_cache): - q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) - - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty( - (q.shape[0], layer.tp_q_head_num * layer.v_head_dim), - dtype=self.input_dtype, + K_QScale=self.k_qscale[:, :total_tokens], + V_QScale=self.v_qscale[:, :total_tokens], + softmax_scale=layer.scaling, mask=1, ) else: - o = torch.empty_like(q, dtype=self.input_dtype) - - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v + q_3d = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + pa_fwd_asm( + Q=q_3d, K=new_key_cache, V=new_value_cache, + block_tables=self.forward_metadata.page_table, + context_lens=self.forward_metadata.kv_lens, + block_tables_stride0=self.forward_metadata.page_table.stride(0), + K_QScale=self.k_scale, V_QScale=self.v_scale, out_=o, ) - k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - - work_metadata = self.forward_metadata.work_metadata - work_indptr = self.forward_metadata.work_indptr - work_info_set = self.forward_metadata.work_info_set - reduce_indptr = self.forward_metadata.reduce_indptr - reduce_final_map = self.forward_metadata.reduce_final_map - reduce_partial_map = self.forward_metadata.reduce_partial_map - num_kv_splits = self.forward_metadata.num_kv_splits - - if layer.layer_id == 0: - _q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) - _k_view = k_buffer.view(-1, 1, 1, layer.qk_head_dim) - _o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) - # print( - # f"[MLA_DECODE_DBG] layer=0" - # f" q={tuple(_q_view.shape)} q.dtype={_q_view.dtype}" - # f" k_buf={tuple(_k_view.shape)} k_buf.dtype={_k_view.dtype}" - # f" o={tuple(_o_view.shape)} 
o.dtype={_o_view.dtype}" - # f" qo_indptr={self.forward_metadata.qo_indptr.tolist()}" - # f" kv_indptr={self.forward_metadata.kv_indptr.tolist()}" - # f" kv_indices_len={self.forward_metadata.kv_indices.shape[0]}" - # f" kv_indices_max={self.forward_metadata.kv_indices.max().item()}" - # f" kv_last_page_len={self.forward_metadata.kv_last_page_len.tolist()}" - # f" max_q_len={self.forward_metadata.max_q_len}" - # f" sm_scale={layer.scaling}" - # f" logit_cap={layer.logit_cap}" - # f" k_scale={layer.k_scale}" - # f" num_kv_splits={num_kv_splits}" - # f" page_size={self.page_size}" - # f" work_metadata={tuple(work_metadata.shape) if work_metadata is not None else None}" - # f" work_indptr={tuple(work_indptr.shape) if work_indptr is not None else None}" - # f" work_info_set={tuple(work_info_set.shape) if work_info_set is not None else None}" - # f" reduce_indptr={tuple(reduce_indptr.shape) if reduce_indptr is not None else None} val={reduce_indptr.tolist() if reduce_indptr is not None and reduce_indptr.numel() < 20 else 'big'}" - # f" reduce_final_map={tuple(reduce_final_map.shape) if reduce_final_map is not None else None}" - # f" reduce_partial_map={tuple(reduce_partial_map.shape) if reduce_partial_map is not None else None}" - # f" intra_batch_mode={_sglang_aiter.intra_batch_mode}" - # f" _use_mla_ps_kernel={_sglang_aiter._use_mla_ps_kernel}" - # f" fast_mode={_sglang_aiter.fast_mode}" - # , flush=True, - # ) - - mla_decode_fwd( - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k_buffer.view(-1, 1, 1, layer.qk_head_dim), - o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - self.forward_metadata.qo_indptr, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.forward_metadata.kv_last_page_len, - self.forward_metadata.max_q_len, - sm_scale=layer.scaling, - logit_cap=layer.logit_cap, - work_meta_data=work_metadata, - work_indptr=work_indptr, - work_info_set=work_info_set, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - q_scale=layer.k_scale, - kv_scale=layer.k_scale, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - num_kv_splits=num_kv_splits, - ) - - return o + return o.view(-1, layer.tp_q_head_num * head_dim_out) From f9e544fa94e44e27fcb4cb65246e4d587494e2e9 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Thu, 26 Mar 2026 04:57:38 +0000 Subject: [PATCH 12/15] clean code: separate sglang mla path Signed-off-by: zhuyuhua-v --- atom/models/deepseek_v2.py | 511 +--------- atom/models/qwen3_moe.py | 3 + .../attention_backend/sgl_attn_backend.py | 928 +++++++----------- atom/plugin/sglang/__init__.py | 0 atom/plugin/sglang/mla.py | 659 +++++++++++++ 5 files changed, 1031 insertions(+), 1070 deletions(-) create mode 100644 atom/plugin/sglang/__init__.py create mode 100644 atom/plugin/sglang/mla.py diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index f5535bc31..6d2bb86ca 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -70,7 +70,6 @@ from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE, _has_module, quark_post_load_weights from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled -from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE from atom.models.utils import ( IntermediateTensors, PPMissingLayer, @@ -86,23 +85,6 @@ from transformers import PretrainedConfig from atom.plugin.prepare import is_sglang -# from vllm.model_executor.layers.quantization.utils.fp8_utils import 
per_token_group_quant_fp8 - -from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context -from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp -from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode -from sglang.srt.configs.model_config import is_deepseek_nsa -from sglang.srt.models.deepseek_common.utils import _use_aiter_gfx95, _use_aiter, _is_gfx95_supported, _is_hip -from sglang.srt.layers.quantization.rocm_mxfp4_utils import batched_gemm_afp4wfp4_pre_quant -from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, -) -from sglang.srt.layers.quantization.fp8_kernel import ( - fp8_dtype, - per_tensor_quant_mla_fp8, - per_token_group_quant_mla_deep_gemm_masked_fp8, -) - logger = logging.getLogger("atom") if use_triton_gemm(): try: @@ -124,33 +106,6 @@ gemm_a16w8_blockscale_preshuffle = None -from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 - -from sglang.srt.utils.custom_op import register_custom_op - -# TODO(yuwei): remove this wrapper after sgl-kernel registers its own fake/meta impl -# Wrap bmm_fp8 as a custom op so torch.compile does not trace into -# torch.cuda.current_blas_handle() (which returns a non-Tensor). -@register_custom_op(mutates_args=["out"]) -def _bmm_fp8_op( - A: torch.Tensor, - B: torch.Tensor, - out: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, -) -> None: - _raw_bmm_fp8(A, B, A_scale, B_scale, out.dtype, out) - -def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): - if out is None: - out = torch.empty( - (A.shape[0], A.shape[1], B.shape[2]), - device=A.device, - dtype=dtype, - ) - _bmm_fp8_op(A, B, out, A_scale, B_scale) - return out - ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION @@ -1498,282 +1453,11 @@ def __init__( self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True - # for sglang plugin mode - self.use_nsa = is_deepseek_nsa(config) - self.use_deep_gemm_bmm = False - self.alt_stream = None - self.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 - self.w_kc, self.w_vc = None, None - self.w_scale = None - self.w_scale_k = None - self.w_scale_v = None - - def _mla_absorbed_bmm(self, inp, weight, weight_scale, weight_scale_k, out_dim): - """Shared batched matmul for MLA absorbed weights (w_kc / w_vc). - - Handles deep_gemm, mxfp4, fp8-triton, fp8-cublas, and bf16 fallback paths. 
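# The wrapper removed above exists so torch.compile never traces into the raw
# sgl-kernel bmm_fp8 call (whose body touches torch.cuda.current_blas_handle()).
# A generic sketch of the same idea using plain torch.library, assuming
# PyTorch >= 2.4; the op name "atom_sketch::opaque_gemm" and the matmul body
# are placeholders, not the real sgl-kernel op:
import torch

@torch.library.custom_op("atom_sketch::opaque_gemm", mutates_args=("out",))
def opaque_gemm(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    # the body may touch non-traceable state (cuBLAS handles, custom CUDA kernels);
    # the compiler sees only the registered schema, never the internals
    torch.matmul(a, b, out=out)

@opaque_gemm.register_fake
def _(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    # shape/dtype-only stand-in used while tracing with fake tensors
    return None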
- inp: (num_tokens, num_heads, in_dim) — token-major - Returns: (num_tokens, num_heads, out_dim) — token-major - """ - if self.use_deep_gemm_bmm: - val, scale, masked_m, expected_m, aligned_m = ( - per_token_group_quant_mla_deep_gemm_masked_fp8(inp.transpose(0, 1)) - ) - out = inp.new_empty((self.num_local_heads, aligned_m, out_dim)) - deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( - (val, scale), (weight, weight_scale_k), out, masked_m, expected_m, - ) - return out[:, :expected_m, :].transpose(0, 1) - - if _is_hip: - if _use_aiter_gfx95 and weight.dtype == torch.uint8: - x = inp.transpose(0, 1) - out = torch.empty( - x.shape[0], x.shape[1], weight.shape[2], - device=x.device, dtype=torch.bfloat16, - ) - batched_gemm_afp4wfp4_pre_quant( - x, weight.transpose(-2, -1), - weight_scale_k.transpose(-2, -1), - torch.bfloat16, out, - ) - return out.transpose(0, 1) - - if (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) or ( - get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz - ): - out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( - X=inp, WQ=weight.transpose(-1, -2), - w_scale=weight_scale, group_size=128, - YQ=None, transpose_bm=False, transpose_bm_in=True, - dtype=torch.bfloat16, - ) - return out.transpose(0, 1) - - out = torch.bmm( - inp.to(torch.bfloat16).transpose(0, 1), - weight.to(torch.bfloat16) * weight_scale, - ) - return out.transpose(0, 1) - - # CUDA fp8 path - if weight.dtype == torch.float8_e4m3fn: - val, scale = per_tensor_quant_mla_fp8( - inp.transpose(0, 1), - torch.zeros((1,), dtype=torch.float32, device=inp.device), - ) - out = bmm_fp8(val, weight, scale, weight_scale, torch.bfloat16) - return out.transpose(0, 1) - - # bf16 fallback - return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) - - def _forward_sgl_prepare( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs: dict[str, Any] | None - ) -> torch.Tensor: - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - - forward_batch = model_kwargs.get("forward_batch", None) - zero_allocator = model_kwargs.get("zero_allocator", None) - llama_4_scaling = model_kwargs.get("llama_4_scaling", None) - q_lora = None - topk_indices = None - if self.q_lora_rank is not None: - q, latent_cache = ( - get_attn_tp_context() - .fetch_qkv_latent() - .split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - ) - - if q.shape[0] != positions.shape[0] and get_tensor_model_parallel_world_size() > 1: - qkv_lora = torch.cat([q, latent_cache], dim=-1) - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] < positions.shape[0]: - raise RuntimeError( - f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, expected {positions.shape[0]}" - ) - qkv_lora = qkv_lora[: positions.shape[0]] - q, latent_cache = torch.split( - qkv_lora, - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - - k_nope = latent_cache[..., : self.kv_lora_rank] - - # overlap qk norm - if self.alt_stream is not None and get_is_capture_mode(): - current_stream = torch.cuda.current_stream() - self.alt_stream.wait_stream(current_stream) - q = self.q_a_layernorm(q) - with torch.cuda.stream(self.alt_stream): - k_nope = self.kv_a_layernorm(k_nope) - current_stream.wait_stream(self.alt_stream) - else: - q = self.q_a_layernorm(q) - k_nope = self.kv_a_layernorm(k_nope) - - if self.use_nsa: - if q_lora is None: - q_lora = q - - # overlap q_b_proj and 
indexer during decode - if ( - self.alt_stream is not None - and get_is_capture_mode() - and forward_batch.forward_mode.is_decode_or_idle() - and q_lora is not None - ): - current_stream = torch.cuda.current_stream() - self.alt_stream.wait_stream(current_stream) - with torch.cuda.stream(self.alt_stream): - k_nope = k_nope.unsqueeze(1) - q = self.q_b_proj(q).view( - -1, self.num_local_heads, self.qk_head_dim - ) - topk_indices = self.indexer( - x=hidden_states, q_lora=q_lora, positions=positions, - forward_batch=forward_batch, layer_id=self.layer_num, - ) - current_stream.wait_stream(self.alt_stream) - else: - k_nope = k_nope.unsqueeze(1) - q = self.q_b_proj(q).view(-1, self.num_local_heads, self.qk_head_dim) - if q_lora is not None: - topk_indices = self.indexer( - x=hidden_states, q_lora=q_lora, positions=positions, - forward_batch=forward_batch, layer_id=self.layer_num, - ) - else: - q = self.q_proj(hidden_states).view( - -1, self.num_local_heads, self.qk_head_dim - ) - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - k_nope = latent_cache[..., : self.kv_lora_rank] - k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1) - - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1) - - q_nope_out = self._mla_absorbed_bmm( - q_nope, self.w_kc, self.w_scale, self.w_scale_k, self.kv_lora_rank, - ) - - if self.rotary_emb is not None and not self.use_fused_qk_rope_concat_and_cache_mla: - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - - if nsa_use_prefill_cp(forward_batch): - k_nope, k_pe = self.rebuild_cp_kv_cache( - latent_cache, forward_batch, k_nope, k_pe - ) - - return ( - q_pe, k_pe, q_nope_out, k_nope, - forward_batch, zero_allocator, positions, topk_indices, llama_4_scaling, - ) - - def _forward_sgl_core( - self, - q_pe, k_pe, q_nope_out, k_nope, - forward_batch, zero_allocator, positions, topk_indices, llama_4_scaling, - ): - save_kv_cache = True - - if self.use_fused_qk_rope_concat_and_cache_mla: - cos = self.rotary_emb.cos_cache - sin = self.rotary_emb.sin_cache - kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(self.layer_num) - k_scale = self.mla_attn.attn.k_scale - - q, _, k_pe_roped, _ = fused_qk_rope_cat_and_cache_mla( - q_nope_out, q_pe, k_nope, k_pe, - kv_cache, forward_batch.out_cache_loc, positions, - cos, sin, k_scale, self.rotary_emb.is_neox_style, - q_out_dtype=q_nope_out.dtype, - ) - k = torch.cat([k_nope, k_pe_roped], dim=-1) - save_kv_cache = False - else: - q = torch.cat([q_nope_out, q_pe], dim=-1) - k = torch.cat([k_nope, k_pe], dim=-1) - - if llama_4_scaling is not None: - q = q * llama_4_scaling - - attn_output = self.mla_attn( - q, k, k_nope, - forward_batch=forward_batch, - save_kv_cache=save_kv_cache, - **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), - ) - attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - - # up-proj by w_vc - attn_bmm_output = self._mla_absorbed_bmm( - attn_output, self.w_vc, self.w_scale, self.w_scale_v, self.v_head_dim, - ).flatten(1, 2) - - return self.o_proj(attn_bmm_output) - - def prepare_qkv_latent(self, hidden_states: torch.Tensor, forward_batch): - assert self.q_lora_rank is not None - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - qkv_lora = self.fused_qkv_a_proj(hidden_states, hidden_states_scale) + # sglang plugin mode attributes (lazily initialised) + if is_sglang(): + from atom.plugin.sglang.mla import 
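# A minimal sketch of the dual-stream overlap shown in the hunk above (in the
# original it is only taken during CUDA-graph capture, overlapping q_b_proj with
# the NSA indexer): the alternate stream waits on the current stream, runs one
# branch, and the current stream waits before the result is consumed. The
# tensors and the relu/matmul workloads are placeholders.
import torch

if torch.cuda.is_available():
    alt_stream = torch.cuda.Stream()
    current_stream = torch.cuda.current_stream()
    a = torch.randn(1024, 1024, device="cuda")

    alt_stream.wait_stream(current_stream)   # alt sees all prior work on current
    with torch.cuda.stream(alt_stream):
        y = a @ a                            # e.g. the q_b_proj / indexer branch
    x = a.relu()                             # the other branch stays on the default stream
    current_stream.wait_stream(alt_stream)   # rejoin before mixing the results
    z = x + y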
init_sgl_attrs - # Fallback: when communicator does not enable input_scattered gather, - # force qkv latent token dimension to align with positions. - expected_tokens = 0 - if hasattr(forward_batch, "positions") and forward_batch.positions is not None: - expected_tokens = int(forward_batch.positions.shape[0]) - if expected_tokens <= 0: - expected_tokens = int(getattr(forward_batch, "seq_lens_sum", 0) or 0) - - if ( - expected_tokens > 0 - and qkv_lora.shape[0] != expected_tokens - and get_tensor_model_parallel_world_size() > 1 - ): - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] > expected_tokens: - qkv_lora = qkv_lora[:expected_tokens] - elif qkv_lora.shape[0] < expected_tokens: - raise RuntimeError( - f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " - f"expected {expected_tokens}" - ) - return qkv_lora - - def forward_sgl_plugin_mode( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs: dict[str, Any] | None - ) -> torch.Tensor: - from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode - forward_batch = model_kwargs.get("forward_batch", None) - if forward_batch is None: - raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") - - attn_tp_context = get_attn_tp_context() - with attn_tp_context.maybe_input_scattered(forward_batch): - if self.q_lora_rank is not None: - attn_tp_context.set_attn_inputs( - AttentionInputs( - hidden_states, forward_batch, self.prepare_qkv_latent, - ) - ) - prepared = self._forward_sgl_prepare(positions, hidden_states, **model_kwargs) - return self._forward_sgl_core(*prepared) + init_sgl_attrs(self, config) def forward_common( self, @@ -1876,191 +1560,16 @@ def forward( **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: if is_sglang(): - attn_output = self.forward_sgl_plugin_mode(positions, hidden_states, **model_kwargs) - else: - attn_output = self.forward_common(positions, hidden_states, **model_kwargs) - return attn_output + from atom.plugin.sglang.mla import forward_sgl_plugin_mode + return forward_sgl_plugin_mode(self, positions, hidden_states, **model_kwargs) + return self.forward_common(positions, hidden_states, **model_kwargs) def process_weights_after_loading(self) -> None: # only for sglang plugin mode if not is_sglang(): return - self._process_mla_kv_b_proj_after_loading_sgl() - - def _process_mla_kv_b_proj_after_loading_sgl(self) -> None: - # lazy imports: only needed for sglang plugin path - from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz - from sglang.srt.layers.quantization.fp8_utils import ( - block_quant_dequant, - block_quant_to_tensor_quant, - channel_quant_to_tensor_quant, - inverse_transform_scale_ue8m0, - ) - from sglang.srt.layers.quantization.int8_utils import ( - block_dequant as int8_block_dequant, - ) - from sglang.srt.layers.deep_gemm_wrapper import ( - ENABLE_JIT_DEEPGEMM, - DEEPGEMM_BLACKWELL, - ) - from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 - from sglang.srt.models.deepseek_common.utils import ( - _is_cpu, - _is_cpu_amx_available, - _is_cuda, - _is_fp8_fnuz, - _is_hip, - _is_npu, - _use_aiter_gfx95, - awq_dequantize_func, - ) - from sglang.srt.utils import bind_or_assign, get_bool_env_var - - # read kv_b_proj weight (awq compatible) - if hasattr(self.kv_b_proj, "qweight"): - awq_dequantize_f = awq_dequantize_func() - if awq_dequantize_f is None: - raise ValueError("AWQ dequantize function is not supported for current device") - w = awq_dequantize_f( - 
self.kv_b_proj.qweight, - self.kv_b_proj.scales, - self.kv_b_proj.qzeros, - ).T - else: - w = self.kv_b_proj.weight - - # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes - # into them (weight_loader_process view-casts a detached copy, leaving - # the nn.Parameter as fnuz). At this point LinearBase's - # process_weights_after_loading hasn't run yet (parent module iterates - # before child in named_modules). View-cast back to fn so the - # normalize path works correctly. - if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: - w = w.view(torch.float8_e4m3fn) - - use_deep_gemm_bmm = False - block_scale = None - weight_block_size = None - - # Derive weight_block_size from ATOM's quant_type system - from aiter import QuantType as _AiterQuantType - _atom_qt = getattr(self.kv_b_proj, "quant_type", None) - if _atom_qt == _AiterQuantType.per_1x128: - weight_block_size = [128, 128] - elif _atom_qt == _AiterQuantType.per_1x32: - weight_block_size = [1, 32] - - # fp8 path - if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): - if weight_block_size is not None: - assert hasattr(self.kv_b_proj, "weight_scale_inv") or hasattr(self.kv_b_proj, "weight_scale") - weight_scale = ( - self.kv_b_proj.weight_scale - if hasattr(self.kv_b_proj, "weight_scale") - else self.kv_b_proj.weight_scale_inv - ) - - if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fn: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=weight_scale, - input_scale=None, - ) - else: - weight = w - - if ( - should_deepgemm_weight_requant_ue8m0( - weight_block_size=weight_block_size - ) - and getattr(weight_scale, "format_ue8m0", False) - ): - weight_scale = inverse_transform_scale_ue8m0(weight_scale, mn=weight.shape[-2]) - - if _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128: - if ( - ENABLE_JIT_DEEPGEMM - and not DEEPGEMM_BLACKWELL - and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") - ): - block_scale = weight_scale - use_deep_gemm_bmm = True - else: - w = block_quant_dequant( - weight, - weight_scale, - weight_block_size, - torch.bfloat16, - ) - else: - w, scale = block_quant_to_tensor_quant(weight, weight_scale, weight_block_size) - self.w_scale = scale - else: - if w.dtype == torch.float8_e4m3fn and _is_fp8_fnuz: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self.kv_b_proj.weight_scale, - input_scale=None, - ) - else: - weight = w - weight_scale = self.kv_b_proj.weight_scale - - w, scale = channel_quant_to_tensor_quant(weight, weight_scale) - self.w_scale = scale - - # int8 path - if w.dtype == torch.int8: - if weight_block_size is not None: - assert hasattr(self.kv_b_proj, "weight_scale_inv") - w = int8_block_dequant( - w, - self.kv_b_proj.weight_scale_inv, - weight_block_size, - ).to(torch.bfloat16) - else: - w = w.to(torch.bfloat16) * self.kv_b_proj.weight_scale.to(torch.bfloat16) - - # split to kc/vc - w_kc, w_vc = w.unflatten( - 0, (-1, self.qk_nope_head_dim + self.v_head_dim) - ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1) - - # quark fp4 special path (ATOM utility) - quant_method = getattr(self.kv_b_proj, "quant_method", None) - quant_config = getattr(quant_method, "quant_config", None) - if _use_aiter_gfx95 and quant_config is not None and quant_config.get_name() == "quark": - w_kc, self.w_scale_k, w_vc, self.w_scale_v = quark_post_load_weights(self, w, "mxfp4") - - if not use_deep_gemm_bmm: - self.w_kc = bind_or_assign( - self.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) - ) - w_vc = 
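# A worked example of the kv_b_proj split above, with illustrative DeepSeek-like
# dimensions: the fused (heads * (nope + v), kv_lora_rank) weight is unflattened
# per head and split into the absorbed key half (w_kc) and value half (w_vc).
import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
assert w_kc.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
assert w_vc.shape == (num_heads, v_head_dim, kv_lora_rank)
# the original then lays these out for bmm: w_kc stays (heads, nope, lora) while
# w_vc is transposed to (heads, lora, v_head_dim) for the output up-projection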
w_vc.contiguous().transpose(1, 2) - if _is_npu: - w_vc = w_vc.contiguous() - self.w_vc = bind_or_assign(self.w_vc, w_vc) - - if hasattr(self.kv_b_proj, "weight_scale") and self.w_scale is None: - self.w_scale = bind_or_assign(self.w_scale, self.kv_b_proj.weight_scale) - if _is_hip: - self.w_scale *= 2.0 - - if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: - self.w_kc = self.w_kc.to(torch.bfloat16) * self.w_scale - self.w_vc = self.w_vc.to(torch.bfloat16) * self.w_scale - else: - num_tiles_k = self.qk_nope_head_dim // weight_block_size[1] - num_tiles_n = self.v_head_dim // weight_block_size[0] - ws_kc, ws_vc = block_scale.unflatten(0, (-1, (num_tiles_k + num_tiles_n))).split( - [num_tiles_k, num_tiles_n], dim=1 - ) - - self.w_scale_k = bind_or_assign(self.w_scale_k, ws_kc.transpose(1, 2).contiguous()) - self.w_scale_v = bind_or_assign(self.w_scale_v, ws_vc.contiguous()) - self.w_kc = bind_or_assign(self.w_kc, w_kc.transpose(1, 2).contiguous()) - self.w_vc = bind_or_assign(self.w_vc, w_vc.contiguous()) - self.use_deep_gemm_bmm = True + from atom.plugin.sglang.mla import process_mla_kv_b_proj_after_loading + process_mla_kv_b_proj_after_loading(self) class DeepseekV2DecoderLayer(nn.Module): @@ -2262,6 +1771,7 @@ def __init__( prefix: str = "", layer_type: type[nn.Module] = DeepseekV2DecoderLayer, ): + # logger.info(f"atom call DeepseekV2Model") super().__init__() config = atom_config.hf_config @@ -2415,6 +1925,7 @@ def __init__( if is_sglang(): from sglang.srt.configs.model_config import is_deepseek_nsa + from sglang.srt.layers.communicator import get_attn_tp_context get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index cb6d74c75..421a68c78 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -468,6 +468,9 @@ def forward( inputs_embeds: torch.Tensor | None = None, **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: + # import logging + # logger = logging.getLogger("atom.models.qwen3_moe") + # logger.info(f"atom call Qwen3MoeModel") if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index a7b42708a..213c9db38 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -271,251 +271,210 @@ def __init__( def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" + if forward_batch.forward_mode.is_decode_or_idle(): + self._init_forward_metadata_decode(forward_batch) + else: + self._init_forward_metadata_extend(forward_batch) + self._fixup_page_table(forward_batch) + + def _init_forward_metadata_decode(self, forward_batch: ForwardBatch): bs = forward_batch.batch_size - kv_indptr = self.kv_indptr spec_info = forward_batch.spec_info - qo_indptr = None - kv_last_page_len = None - max_q_len = None - page_table = None - if forward_batch.forward_mode.is_decode_or_idle(): - if spec_info is None: - kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device - ) - create_flashinfer_kv_indices_triton[(bs,)]( - 
self.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - kv_indptr, - None, - kv_indices, - self.req_to_token.stride(0), - ) - else: - kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - bs = kv_indptr.shape[0] - 1 - - if self.use_mla: - qo_indptr = self.qo_indptr_[: bs + 1] - qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0) - kv_last_page_len = self.kv_last_page_len[:bs] - max_q_len = 1 - - work_metadata = None - work_indptr = None - work_info_set = None - reduce_indptr = None - reduce_final_map = None - reduce_partial_map = None - num_kv_splits = None - - if _sglang_aiter._use_mla_ps_kernel: - ( - work_metadata, - work_indptr, - work_info_set, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs) - - num_kv_splits = self.max_split_per_batch - - self.make_mla_meta_data( - qo_indptr, - kv_indptr, - kv_last_page_len, - work_metadata, - work_info_set, - work_indptr, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - max_q_len, - fast_mode=_sglang_aiter.fast_mode, - max_split_per_batch=num_kv_splits, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - ) + if spec_info is None: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 - self.forward_metadata = ForwardMetadata( - kv_indptr, - kv_indices, - qo_indptr, - kv_last_page_len, - max_q_len, - None, # max_kv_len - None, # page_table - None, # kv_lens - work_metadata=work_metadata, - work_info_set=work_info_set, - work_indptr=work_indptr, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - num_kv_splits=num_kv_splits, - ) + if self.use_mla: + self._init_decode_mla(bs, kv_indptr, kv_indices) + else: + self._init_decode_mha(bs, kv_indptr, kv_indices, forward_batch) - else: - if self.decode_using_pa_ps: - # Non-MLA decode mode: use same logic as CUDA Graph mode for page_table construction - seq_lens_cpu = forward_batch.seq_lens_cpu - if seq_lens_cpu is None: - seq_lens_cpu = forward_batch.seq_lens.cpu() - - # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph) - page_table_persistent = self.page_table - seq_lens_persistent = self.seq_lens - seq_lens_persistent.fill_(0) - page_table_persistent.fill_(0) - seq_lens_persistent[:bs].copy_( - forward_batch.seq_lens, non_blocking=True - ) - max_seq_pages = ( - seq_lens_cpu.max().item() + self.page_size - 1 - ) // self.page_size + 1 - page_table = self.req_to_token[ - forward_batch.req_pool_indices[:, None], - self.strided_indices[:max_seq_pages][None, :], - ] - page_table_persistent[:bs, :max_seq_pages].copy_( - page_table // self.page_size, non_blocking=True - ) - else: - page_table = forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, : - ] + def _init_decode_mla(self, bs, kv_indptr, kv_indices): + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0) + kv_last_page_len = self.kv_last_page_len[:bs] + max_q_len = 1 
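For readers following the metadata layout here: kv_indptr and kv_indices form a flashinfer-style CSR pair over the paged KV cache. kv_indptr holds per-request offsets (a cumulative sum of seq_lens with a leading zero), and kv_indices holds the physical KV-cache slot ids gathered row by row from req_to_token. The sketch below is a plain-PyTorch reference for that layout only; in the patch it is built on device by create_flashinfer_kv_indices_triton, and build_kv_indices_reference is a hypothetical helper name used purely for illustration, not part of the patch.

import torch

def build_kv_indices_reference(req_to_token, req_pool_indices, seq_lens):
    # kv_indptr[i] .. kv_indptr[i + 1] delimits request i's span in kv_indices.
    kv_indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int64)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    # Gather the physical slot ids of every cached token, request by request.
    parts = [
        req_to_token[pool_idx, :length]
        for pool_idx, length in zip(req_pool_indices.tolist(), seq_lens.tolist())
    ]
    kv_indices = (
        torch.cat(parts).to(torch.int32)
        if parts
        else torch.empty(0, dtype=torch.int32)
    )
    return kv_indptr.to(torch.int32), kv_indices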
- self.forward_metadata = ForwardMetadata( - kv_indptr, - kv_indices, - None, # qo_indptr not used in non-MLA mode - None, # kv_last_page_len not used in non-MLA mode - 1, # max_q_len = 1 for decode mode - None, - ( - page_table_persistent[:bs, :max_seq_pages] - if self.decode_using_pa_ps - else page_table - ), - ( - seq_lens_persistent[:bs] - if self.decode_using_pa_ps - else forward_batch.seq_lens - ), - ) + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + ( + work_metadata, work_indptr, work_info_set, + reduce_indptr, reduce_final_map, reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs) + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, kv_indptr, kv_last_page_len, + work_metadata, work_info_set, work_indptr, + reduce_indptr, reduce_final_map, reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + self.forward_metadata = ForwardMetadata( + kv_indptr, kv_indices, qo_indptr, kv_last_page_len, + max_q_len, None, None, None, + work_metadata=work_metadata, work_info_set=work_info_set, + work_indptr=work_indptr, reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) - # Build pa_metadata for pa_persistent_fwd - if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + def _init_decode_mha(self, bs, kv_indptr, kv_indices, forward_batch): + if self.decode_using_pa_ps: + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + forward_batch.req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) + self.forward_metadata = ForwardMetadata( + kv_indptr, kv_indices, None, None, 1, None, + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) else: - prefix_lens = forward_batch.extend_prefix_lens - - if self.use_mla: - self.mla_indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - forward_batch.extend_seq_lens, - forward_batch.extend_seq_lens.max().item(), - forward_batch.seq_lens.max().item(), - spec_info=None - ) + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( + kv_indptr, kv_indices, None, None, 1, None, + page_table, forward_batch.seq_lens, + ) - max_q_len = self.mla_indices_updater_prefill.max_q_len - qo_indptr = self.mla_indices_updater_prefill.qo_indptr + def _init_forward_metadata_extend(self, forward_batch: ForwardBatch): + bs = forward_batch.batch_size - work_metadata = None - work_indptr = None - work_info_set = None - reduce_indptr = None - reduce_final_map 
= None - fp8_prefill_kv_indices = None - reduce_partial_map = None + if self.use_mla: + self._init_extend_mla(bs, forward_batch) + else: + self._init_extend_mha(bs, forward_batch) + + def _init_extend_mla(self, bs, forward_batch): + self.mla_indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + forward_batch.extend_seq_lens, + forward_batch.extend_seq_lens.max().item(), + forward_batch.seq_lens.max().item(), + spec_info=None, + ) - from sglang.srt.utils import is_gfx95_supported - _use_fp8_prefill_attn = ( - get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported() - ) - if _use_fp8_prefill_attn: - tile_q = 256 - qlen_granularity = tile_q // (self.num_head // self.num_kv_head) - ( - work_metadata, - work_indptr, - work_info_set, - reduce_indptr, - reduce_final_map, - reduce_partial_map - ) = self.make_mla_prefill_ps_meta_data_buffer( - bs, max_q_len, qlen_granularity - ) + max_q_len = self.mla_indices_updater_prefill.max_q_len + qo_indptr = self.mla_indices_updater_prefill.qo_indptr + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + fp8_prefill_kv_indices = None - self.make_mla_prefill_ps_meta_data( - qo_indptr, - qo_indptr, - forward_batch.seq_lens, - work_metadata, - work_indptr, - work_info_set, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - is_causal=True, - ) + from sglang.srt.utils import is_gfx95_supported + _use_fp8_prefill_attn = ( + get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") + and is_gfx95_supported() + ) + if _use_fp8_prefill_attn: + tile_q = 256 + qlen_granularity = tile_q // (self.num_head // self.num_kv_head) + ( + work_metadata, work_indptr, work_info_set, + reduce_indptr, reduce_final_map, reduce_partial_map, + ) = self.make_mla_prefill_ps_meta_data_buffer( + bs, max_q_len, qlen_granularity + ) + self.make_mla_prefill_ps_meta_data( + qo_indptr, qo_indptr, forward_batch.seq_lens, + work_metadata, work_indptr, work_info_set, + reduce_indptr, reduce_final_map, reduce_partial_map, + is_causal=True, + ) + total_s = int(forward_batch.extend_seq_lens.sum()) + fp8_prefill_kv_indices = torch.arange( + total_s, device=self.device, dtype=torch.int32 + ) - total_s = int(forward_batch.extend_seq_lens.sum()) - fp8_prefill_kv_indices = torch.arange( - total_s, device=self.device, dtype=torch.int32 - ) + self.forward_metadata = ForwardMetadata( + self.mla_indices_updater_prefill.kv_indptr, + self.mla_indices_updater_prefill.kv_indices, + qo_indptr, + self.kv_last_page_len[:bs], + max_q_len, + self.mla_indices_updater_prefill.max_kv_len, + None, None, + work_metadata=work_metadata, work_info_set=work_info_set, + work_indptr=work_indptr, reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + fp8_prefill_kv_indices=fp8_prefill_kv_indices, + ) - self.forward_metadata = ForwardMetadata( - self.mla_indices_updater_prefill.kv_indptr, - self.mla_indices_updater_prefill.kv_indices, - qo_indptr, - self.kv_last_page_len[:bs], - max_q_len, - self.mla_indices_updater_prefill.max_kv_len, - None, - None, - work_metadata=work_metadata, - work_info_set=work_info_set, - work_indptr=work_indptr, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - fp8_prefill_kv_indices=fp8_prefill_kv_indices, - ) - else: - self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - 
forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens, - encoder_lens=forward_batch.encoder_lens, - spec_info=None, - ) - # Get page_table for mha_batch_prefill_func - page_table = forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, : - ] - self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, - self.qo_indptr[ - : bs + 1 - ], # qo_indptr is set by indices_updater_prefill - None, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, - None, - forward_batch.seq_lens, - ) + def _init_extend_mha(self, bs, forward_batch): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + forward_batch.extend_prefix_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.qo_indptr[: bs + 1], + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + None, + forward_batch.seq_lens, + ) + def _fixup_page_table(self, forward_batch: ForwardBatch): + """Post-process page_table for non-MLA extend mode.""" if ( forward_batch.forward_mode.is_extend() and not self.use_mla @@ -546,109 +505,38 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): // self.page_size ) - def _allocate_pa_metadata_buffers( - self, - work_metadata_ptrs_size, - work_metadata_ptrs_type, - work_indptr_size, - work_indptr_type, - work_info_size, - work_info_type, - reduce_indptr_size, - reduce_indptr_type, - reduce_final_map_size, - reduce_final_map_type, - reduce_partial_map_size, - reduce_partial_map_type, - ): - """Allocate or reuse pa_metadata buffers.""" + def _ensure_buffer(self, name, size, dtype, zero=True): + """Allocate or reuse a pa_metadata buffer, growing if needed.""" if self.pa_metadata_buffers is None: self.pa_metadata_buffers = {} + size_val = size[0] if isinstance(size, (tuple, list)) else size + buf = self.pa_metadata_buffers.get(name) + needs_alloc = ( + buf is None + or buf.shape[0] < size_val + or (isinstance(size, (tuple, list)) and len(buf.shape) < len(size)) + ) + if needs_alloc: + factory = torch.zeros if zero else torch.empty + self.pa_metadata_buffers[name] = factory(size, dtype=dtype, device=self.device) + elif zero: + self.pa_metadata_buffers[name].zero_() - def _get_size_val(size): - return size[0] if isinstance(size, tuple) else size - - # Allocate work_metadata_ptrs - size_val = _get_size_val(work_metadata_ptrs_size) - if ( - "work_metadata_ptrs" not in self.pa_metadata_buffers - or self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val - ): - self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty( - work_metadata_ptrs_size, - dtype=work_metadata_ptrs_type, - device=self.device, - ) - - # Allocate work_indptr - size_val = _get_size_val(work_indptr_size) - if ( - "work_indptr" not in self.pa_metadata_buffers - or self.pa_metadata_buffers["work_indptr"].shape[0] < size_val - ): - self.pa_metadata_buffers["work_indptr"] = torch.zeros( - work_indptr_size, dtype=work_indptr_type, device=self.device - ) - else: - self.pa_metadata_buffers["work_indptr"].zero_() - - # Allocate work_info - size_val = _get_size_val(work_info_size) - if ( - "work_info" not in self.pa_metadata_buffers - or 
len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size) - or self.pa_metadata_buffers["work_info"].shape[0] < size_val - ): - self.pa_metadata_buffers["work_info"] = torch.zeros( - work_info_size, dtype=work_info_type, device=self.device - ) - else: - self.pa_metadata_buffers["work_info"].zero_() - - # Allocate reduce_indptr - size_val = _get_size_val(reduce_indptr_size) - if ( - "reduce_indptr" not in self.pa_metadata_buffers - or self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val - ): - self.pa_metadata_buffers["reduce_indptr"] = torch.zeros( - reduce_indptr_size, dtype=reduce_indptr_type, device=self.device - ) - else: - self.pa_metadata_buffers["reduce_indptr"].zero_() - - # Allocate reduce_final_map - size_val = _get_size_val(reduce_final_map_size) - if ( - "reduce_final_map" not in self.pa_metadata_buffers - or len(self.pa_metadata_buffers["reduce_final_map"].shape) - < len(reduce_final_map_size) - or self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val - ): - self.pa_metadata_buffers["reduce_final_map"] = torch.zeros( - reduce_final_map_size, dtype=reduce_final_map_type, device=self.device - ) - else: - self.pa_metadata_buffers["reduce_final_map"].zero_() + def _allocate_pa_metadata_buffers(self, buffer_specs): + """Allocate or reuse pa_metadata buffers. - # Allocate reduce_partial_map - reduce_partial_map_size_val = ( - reduce_partial_map_size - if isinstance(reduce_partial_map_size, int) - else reduce_partial_map_size[0] - ) - if ( - "reduce_partial_map" not in self.pa_metadata_buffers - or self.pa_metadata_buffers["reduce_partial_map"].shape[0] - < reduce_partial_map_size_val - ): - self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros( - reduce_partial_map_size, - dtype=reduce_partial_map_type, - device=self.device, - ) - else: - self.pa_metadata_buffers["reduce_partial_map"].zero_() + Args: + buffer_specs: sequence of ((size, dtype), ...) tuples from get_pa_metadata_info_v1, + in order: work_metadata_ptrs, work_indptr, work_info, + reduce_indptr, reduce_final_map, reduce_partial_map. 
+ """ + names = [ + "work_metadata_ptrs", "work_indptr", "work_info", + "reduce_indptr", "reduce_final_map", "reduce_partial_map", + ] + zero_flags = [False, True, True, True, True, True] + for name, (size, dtype), zero in zip(names, buffer_specs, zero_flags): + self._ensure_buffer(name, size, dtype, zero=zero) def _build_pa_metadata_for_decode( self, @@ -670,33 +558,8 @@ def _build_pa_metadata_for_decode( if tp_q_head_num is None: tp_q_head_num = self.num_head - # kv_dtype_for_metadata = dtypes.fp8 - ( - (work_metadata_ptrs_size, work_metadata_ptrs_type), - (work_indptr_size, work_indptr_type), - (work_info_size, work_info_type), - (reduce_indptr_size, reduce_indptr_type), - (reduce_final_map_size, reduce_final_map_type), - (reduce_partial_map_size, reduce_partial_map_type), - ) = get_pa_metadata_info_v1( - batch_size, - self.num_kv_head, - ) - # Allocate metadata buffers with reuse optimization for multi-layer forward passes - self._allocate_pa_metadata_buffers( - work_metadata_ptrs_size, - work_metadata_ptrs_type, - work_indptr_size, - work_indptr_type, - work_info_size, - work_info_type, - reduce_indptr_size, - reduce_indptr_type, - reduce_final_map_size, - reduce_final_map_type, - reduce_partial_map_size, - reduce_partial_map_type, - ) + buffer_specs = get_pa_metadata_info_v1(batch_size, self.num_kv_head) + self._allocate_pa_metadata_buffers(buffer_specs) qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) @@ -837,32 +700,8 @@ def init_cuda_graph_state( self.reduce_partial_map = None if self.decode_using_pa_ps and not self.use_mla: - ( - (work_metadata_ptrs_size, work_metadata_ptrs_type), - (work_indptr_size, work_indptr_type), - (work_info_size, work_info_type), - (reduce_indptr_size, reduce_indptr_type), - (reduce_final_map_size, reduce_final_map_type), - (reduce_partial_map_size, reduce_partial_map_type), - ) = get_pa_metadata_info_v1( - max_bs, - self.num_kv_head, - ) - - self._allocate_pa_metadata_buffers( - work_metadata_ptrs_size, - work_metadata_ptrs_type, - work_indptr_size, - work_indptr_type, - work_info_size, - work_info_type, - reduce_indptr_size, - reduce_indptr_type, - reduce_final_map_size, - reduce_final_map_type, - reduce_partial_map_size, - reduce_partial_map_type, - ) + buffer_specs = get_pa_metadata_info_v1(max_bs, self.num_kv_head) + self._allocate_pa_metadata_buffers(buffer_specs) def _init_mla_cuda_graph_metadata(self, bs, req_pool_indices, seq_lens): """Shared MLA decode metadata setup for CUDA graph capture/replay.""" @@ -1132,211 +971,160 @@ def _forward_extend_mla_normal( kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, ): - """Normal MLA extend (not target_verify, not draft_extend). 
- - Three sub-paths mirroring sglang aiter_backend: - 1) No prefix -> fp8 prefill kernel (mla_prefill_ps_asm_fwd) or flash_attn fallback - 2) Has prefix, absorbed weights differ -> decompress via kv_b_proj + flash_attn - 3) Has prefix, qk_head_dim matches -> mla_prefill_fwd kernel - """ + """Normal MLA extend (not target_verify, not draft_extend).""" extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) if kv_indices.shape[0] == 0 or extend_no_prefix: - # --- Sub-path 1: no prefix, pure prefill --- - use_fp8_prefill = ( - self.forward_metadata.fp8_prefill_kv_indices is not None + return self._extend_mla_no_prefix( + q, k, v, layer, kv_lora_rank, qk_rope_head_dim, + max_q_len, qo_indptr, + ) + elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim): + return self._extend_mla_decompress_prefix( + q, layer, forward_batch, K_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ) + else: + return self._extend_mla_absorbed_prefix( + q, layer, K_Buffer, kv_indptr, kv_indices, qo_indptr, ) - if use_fp8_prefill: - total_s = q.shape[0] - nhead = layer.tp_q_head_num - v_head_dim = layer.v_head_dim - - if q.dtype != dtypes.fp8: - q = q.to(dtypes.fp8) - if k.dtype != dtypes.fp8: - k = k.to(dtypes.fp8) - if v.dtype != dtypes.fp8: - v = v.to(dtypes.fp8) - one_scale = torch.ones( - (), dtype=torch.float32, device=q.device - ) - - kv_indptr_asm = qo_indptr - kv_indices_asm = self.forward_metadata.fp8_prefill_kv_indices - - tile_q = 256 - reduce_indptr = self.forward_metadata.reduce_indptr - reduce_final_map = self.forward_metadata.reduce_final_map - reduce_partial_map = self.forward_metadata.reduce_partial_map - logits = torch.empty( - (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), - dtype=torch.float32, - device=q.device, - ) - attn_lse = torch.empty( - (reduce_partial_map.size(0) * tile_q, nhead), - dtype=torch.float32, - device=q.device, - ) - final_lse = torch.empty( - (total_s, nhead), - dtype=torch.float32, - device=q.device, - ) - output = q.new_empty( - (total_s, nhead, v_head_dim), - dtype=self.input_dtype, - ) + def _extend_mla_no_prefix( + self, q, k, v, layer, kv_lora_rank, qk_rope_head_dim, + max_q_len, qo_indptr, + ): + """No-prefix prefill: FP8 kernel, mla_prefill_fwd, or flash_attn fallback.""" + if self.forward_metadata.fp8_prefill_kv_indices is not None: + return self._extend_mla_fp8_prefill(q, k, v, layer, max_q_len, qo_indptr) - mla_prefill_ps_asm_fwd( - q, - k, - v, - qo_indptr, - kv_indptr_asm, - kv_indices_asm, - self.forward_metadata.work_indptr, - self.forward_metadata.work_info_set, - max_q_len, - layer.scaling, - True, - logits, - attn_lse, - output, - one_scale, - one_scale, - one_scale, - ) - mla_reduce_v1( - logits, - attn_lse, - reduce_indptr, - reduce_final_map, - reduce_partial_map, - tile_q, - output, - final_lse, - ) - elif layer.qk_head_dim == (kv_lora_rank + qk_rope_head_dim) and mla_prefill_fwd is not None: - # Absorbed MLA: head_dim (576) exceeds CK limit (256), - # use mla_prefill_fwd which natively supports large MLA head dims. - # For no-prefix, use input k (bfloat16) directly instead of K_Buffer - # (which may be FP8). mla_prefill_fwd doesn't support FP8 KV. 
- if layer.qk_head_dim != layer.v_head_dim: - output = q.new_empty( - (q.shape[0], layer.tp_q_head_num * layer.v_head_dim) - ) - else: - output = torch.empty_like(q) - total_s = q.shape[0] - temp_kv_indices = torch.arange( - total_s, device=q.device, dtype=torch.int32 - ) - mla_prefill_fwd( - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k.view(-1, 1, 1, layer.qk_head_dim), - output.view(-1, layer.tp_q_head_num, layer.v_head_dim), - qo_indptr, - qo_indptr, - temp_kv_indices, - self.forward_metadata.kv_last_page_len, - max_q_len, - layer.scaling, - layer.logit_cap, - ) + if layer.qk_head_dim == (kv_lora_rank + qk_rope_head_dim) and mla_prefill_fwd is not None: + # Absorbed MLA: head_dim (576) exceeds CK limit (256), + # use mla_prefill_fwd which natively supports large MLA head dims. + if layer.qk_head_dim != layer.v_head_dim: + output = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) else: - output = flash_attn_varlen_func( - q, - k, - v, - qo_indptr, - qo_indptr, - max_q_len, - max_q_len, - softmax_scale=layer.scaling, - causal=True, - ) - return output - - elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim): - # --- Sub-path 2: has prefix, need kv_b_proj decompress --- - K_Buffer = torch.index_select(K_Buffer, 0, kv_indices) - kvc, k_pe = torch.split( - K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1 + output = torch.empty_like(q) + total_s = q.shape[0] + temp_kv_indices = torch.arange(total_s, device=q.device, dtype=torch.int32) + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.view(-1, 1, 1, layer.qk_head_dim), + output.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, qo_indptr, temp_kv_indices, + self.forward_metadata.kv_last_page_len, + max_q_len, layer.scaling, layer.logit_cap, ) + return output - if self.kv_cache_dtype == dtypes.fp8: - dtype = q.dtype - kvc = kvc.to(dtype) - k_pe = k_pe.to(dtype) + return flash_attn_varlen_func( + q, k, v, qo_indptr, qo_indptr, max_q_len, max_q_len, + softmax_scale=layer.scaling, causal=True, + ) - kvprefix = layer.kv_b_proj(kvc.contiguous())[0] - kvprefix = kvprefix.view( - -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim - ) - k_prefix, v_prefix = torch.split( - kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 - ) - k_prefix = torch.cat( - [ - k_prefix, - torch.broadcast_to( - k_pe, - (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]), - ), - ], - dim=-1, - ) + def _extend_mla_fp8_prefill(self, q, k, v, layer, max_q_len, qo_indptr): + """FP8 prefill path using mla_prefill_ps_asm_fwd + mla_reduce_v1.""" + total_s = q.shape[0] + nhead = layer.tp_q_head_num + v_head_dim = layer.v_head_dim + md = self.forward_metadata - assert ( - forward_batch.extend_prefix_lens.shape - == forward_batch.extend_seq_lens.shape - ) + if q.dtype != dtypes.fp8: + q = q.to(dtypes.fp8) + if k.dtype != dtypes.fp8: + k = k.to(dtypes.fp8) + if v.dtype != dtypes.fp8: + v = v.to(dtypes.fp8) + one_scale = torch.ones((), dtype=torch.float32, device=q.device) + + tile_q = 256 + logits = torch.empty( + (md.reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), + dtype=torch.float32, device=q.device, + ) + attn_lse = torch.empty( + (md.reduce_partial_map.size(0) * tile_q, nhead), + dtype=torch.float32, device=q.device, + ) + final_lse = torch.empty((total_s, nhead), dtype=torch.float32, device=q.device) + output = q.new_empty((total_s, nhead, v_head_dim), dtype=self.input_dtype) + + mla_prefill_ps_asm_fwd( + q, k, v, qo_indptr, qo_indptr, + md.fp8_prefill_kv_indices, md.work_indptr, 
md.work_info_set, + max_q_len, layer.scaling, True, + logits, attn_lse, output, one_scale, one_scale, one_scale, + ) + mla_reduce_v1( + logits, attn_lse, md.reduce_indptr, md.reduce_final_map, + md.reduce_partial_map, tile_q, output, final_lse, + ) + return output - o = flash_attn_varlen_func( - q, + def _extend_mla_decompress_prefix( + self, q, layer, forward_batch, K_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ): + """Has prefix, absorbed weights differ: decompress via kv_b_proj + flash_attn.""" + K_Buffer = torch.index_select(K_Buffer, 0, kv_indices) + kvc, k_pe = torch.split(K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1) + + if self.kv_cache_dtype == dtypes.fp8: + dtype = q.dtype + kvc = kvc.to(dtype) + k_pe = k_pe.to(dtype) + + kvprefix = layer.kv_b_proj(kvc.contiguous())[0] + kvprefix = kvprefix.view( + -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim + ) + k_prefix, v_prefix = torch.split( + kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 + ) + k_prefix = torch.cat( + [ k_prefix, - v_prefix, - qo_indptr, - kv_indptr, - max_q_len, - max_kv_len, - softmax_scale=layer.scaling, - causal=True, - ) - return o + torch.broadcast_to( + k_pe, (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]), + ), + ], + dim=-1, + ) - else: - # --- Sub-path 3: has prefix, qk_head_dim == kv_lora_rank + qk_rope_head_dim --- - # Gather needed KV entries and cast to bf16 (K_Buffer may be FP8) - k_selected = torch.index_select(K_Buffer, 0, kv_indices) - if k_selected.dtype != q.dtype: - k_selected = k_selected.to(q.dtype) - compact_kv_indices = torch.arange( - k_selected.shape[0], device=q.device, dtype=torch.int32 - ) + assert forward_batch.extend_prefix_lens.shape == forward_batch.extend_seq_lens.shape - if layer.qk_head_dim != layer.v_head_dim: - o = q.new_empty( - (q.shape[0], layer.tp_q_head_num * layer.v_head_dim) - ) - else: - o = torch.empty_like(q) + return flash_attn_varlen_func( + q, k_prefix, v_prefix, qo_indptr, kv_indptr, + max_q_len, max_kv_len, softmax_scale=layer.scaling, causal=True, + ) - mla_prefill_fwd( - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - k_selected.view(-1, 1, 1, layer.qk_head_dim), - o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - qo_indptr, - kv_indptr, - compact_kv_indices, - self.forward_metadata.kv_last_page_len, - self.forward_metadata.max_q_len, - layer.scaling, - layer.logit_cap, - ) - return o + def _extend_mla_absorbed_prefix( + self, q, layer, K_Buffer, kv_indptr, kv_indices, qo_indptr, + ): + """Has prefix, qk_head_dim == kv_lora_rank + qk_rope_head_dim: mla_prefill_fwd.""" + k_selected = torch.index_select(K_Buffer, 0, kv_indices) + if k_selected.dtype != q.dtype: + k_selected = k_selected.to(q.dtype) + compact_kv_indices = torch.arange( + k_selected.shape[0], device=q.device, dtype=torch.int32 + ) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_selected.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, kv_indptr, compact_kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + layer.scaling, layer.logit_cap, + ) + return o def _call_mla_decode_fwd(self, q, k_buffer, o, layer): """Common mla_decode_fwd invocation shared across decode/extend paths.""" diff --git 
a/atom/plugin/sglang/__init__.py b/atom/plugin/sglang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/sglang/mla.py b/atom/plugin/sglang/mla.py new file mode 100644 index 000000000..e7d52a9c7 --- /dev/null +++ b/atom/plugin/sglang/mla.py @@ -0,0 +1,659 @@ +"""Sglang-specific MLA forward and weight processing for DeepseekV2/V3. + +This module is lazily imported from deepseek_v2.py only when running in sglang +plugin mode (``is_sglang() == True``). Keeping all sglang-dependent imports +here avoids crashing when sglang is not installed. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple, Optional + +import torch +from aiter.dist.parallel_state import get_tensor_model_parallel_world_size, get_tp_group +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla + +# sglang imports +from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context +from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.models.deepseek_common.utils import ( + _use_aiter_gfx95, + _is_hip, + _is_cpu, + _is_cpu_amx_available, + _is_cuda, + _is_fp8_fnuz, + _is_npu, + _use_aiter_gfx95, + awq_dequantize_func, +) +from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( + batched_gemm_afp4wfp4_pre_quant, +) +from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + per_tensor_quant_mla_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, +) +from sglang.srt.utils import bind_or_assign, get_bool_env_var + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + +# --------------------------------------------------------------------------- +# bmm_fp8 custom-op wrapper (adapted from sglang forward_mla.py) +# --------------------------------------------------------------------------- +if _is_cuda: + from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 + from sglang.srt.utils.custom_op import register_custom_op + + @register_custom_op(mutates_args=["out"]) + def _bmm_fp8_op( + A: torch.Tensor, + B: torch.Tensor, + out: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + ) -> None: + _raw_bmm_fp8(A, B, A_scale, B_scale, out.dtype, out) + + def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + _bmm_fp8_op(A, B, out, A_scale, B_scale) + return out + +else: + + def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") + + +# --------------------------------------------------------------------------- +# NamedTuple for prepare → core data flow +# --------------------------------------------------------------------------- +class SglPrepareResult(NamedTuple): + q_pe: torch.Tensor + k_pe: torch.Tensor + q_nope_out: torch.Tensor + k_nope: torch.Tensor + forward_batch: Any + zero_allocator: Any + positions: torch.Tensor + topk_indices: Optional[torch.Tensor] + llama_4_scaling: Optional[Any] + + +# --------------------------------------------------------------------------- +# Init helpers +# --------------------------------------------------------------------------- +def init_sgl_attrs(attn: DeepseekV2MLAAttention, 
config) -> None: + """Initialise sglang-only attributes on DeepseekV2MLAAttention.""" + from sglang.srt.configs.model_config import is_deepseek_nsa + + attn.use_nsa = is_deepseek_nsa(config) + attn.use_deep_gemm_bmm = False + attn.alt_stream = None + attn.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 + attn.w_kc, attn.w_vc = None, None + attn.w_scale = None + attn.w_scale_k = None + attn.w_scale_v = None + + +# --------------------------------------------------------------------------- +# Absorbed batched-matmul (shared by prepare and core) +# --------------------------------------------------------------------------- +def mla_absorbed_bmm( + attn: DeepseekV2MLAAttention, + inp: torch.Tensor, + weight: torch.Tensor, + weight_scale: Optional[torch.Tensor], + weight_scale_k: Optional[torch.Tensor], + out_dim: int, +) -> torch.Tensor: + """Batched matmul for MLA absorbed weights (w_kc / w_vc). + + Handles deep_gemm, mxfp4, fp8-triton, fp8-cublas, and bf16 fallback paths. + inp: (num_tokens, num_heads, in_dim) — token-major + Returns: (num_tokens, num_heads, out_dim) — token-major + """ + if attn.use_deep_gemm_bmm: + from sglang.srt.layers import deep_gemm_wrapper + + val, scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(inp.transpose(0, 1)) + ) + out = inp.new_empty((attn.num_local_heads, aligned_m, out_dim)) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (val, scale), + (weight, weight_scale_k), + out, + masked_m, + expected_m, + ) + return out[:, :expected_m, :].transpose(0, 1) + + if _is_hip: + if _use_aiter_gfx95 and weight.dtype == torch.uint8: + x = inp.transpose(0, 1) + out = torch.empty( + x.shape[0], + x.shape[1], + weight.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + weight.transpose(-2, -1), + weight_scale_k.transpose(-2, -1), + torch.bfloat16, + out, + ) + return out.transpose(0, 1) + + if (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) or ( + get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz + ): + out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=inp, + WQ=weight.transpose(-1, -2), + w_scale=weight_scale, + group_size=128, + YQ=None, + transpose_bm=False, + transpose_bm_in=True, + dtype=torch.bfloat16, + ) + return out.transpose(0, 1) + + out = torch.bmm( + inp.to(torch.bfloat16).transpose(0, 1), + weight.to(torch.bfloat16) * weight_scale, + ) + return out.transpose(0, 1) + + # CUDA fp8 path + if weight.dtype == torch.float8_e4m3fn: + val, scale = per_tensor_quant_mla_fp8( + inp.transpose(0, 1), + torch.zeros((1,), dtype=torch.float32, device=inp.device), + ) + out = bmm_fp8(val, weight, scale, weight_scale, torch.bfloat16) + return out.transpose(0, 1) + + # bf16 fallback + return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) + + +# --------------------------------------------------------------------------- +# Forward: prepare → core +# --------------------------------------------------------------------------- +def forward_sgl_prepare( + attn: DeepseekV2MLAAttention, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs, +) -> SglPrepareResult: + """Prepare QKV for sglang MLA attention (adapted from sglang forward_absorb_prepare).""" + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + + forward_batch = model_kwargs.get("forward_batch", None) + zero_allocator = model_kwargs.get("zero_allocator", None) + 
llama_4_scaling = model_kwargs.get("llama_4_scaling", None) + q_lora = None + topk_indices = None + + if attn.q_lora_rank is not None: + q, latent_cache = ( + get_attn_tp_context() + .fetch_qkv_latent() + .split( + [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], + dim=-1, + ) + ) + + if q.shape[0] != positions.shape[0] and get_tensor_model_parallel_world_size() > 1: + qkv_lora = torch.cat([q, latent_cache], dim=-1) + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] < positions.shape[0]: + raise RuntimeError( + f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, " + f"expected {positions.shape[0]}" + ) + qkv_lora = qkv_lora[: positions.shape[0]] + q, latent_cache = torch.split( + qkv_lora, + [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], + dim=-1, + ) + + k_nope = latent_cache[..., : attn.kv_lora_rank] + + # overlap qk norm + if attn.alt_stream is not None and get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + attn.alt_stream.wait_stream(current_stream) + q = attn.q_a_layernorm(q) + with torch.cuda.stream(attn.alt_stream): + k_nope = attn.kv_a_layernorm(k_nope) + current_stream.wait_stream(attn.alt_stream) + else: + q = attn.q_a_layernorm(q) + k_nope = attn.kv_a_layernorm(k_nope) + + if attn.use_nsa: + if q_lora is None: + q_lora = q + + # overlap q_b_proj and indexer during decode + if ( + attn.alt_stream is not None + and get_is_capture_mode() + and forward_batch.forward_mode.is_decode_or_idle() + and q_lora is not None + ): + current_stream = torch.cuda.current_stream() + attn.alt_stream.wait_stream(current_stream) + with torch.cuda.stream(attn.alt_stream): + k_nope = k_nope.unsqueeze(1) + q = attn.q_b_proj(q).view(-1, attn.num_local_heads, attn.qk_head_dim) + topk_indices = attn.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=attn.layer_num, + ) + current_stream.wait_stream(attn.alt_stream) + else: + k_nope = k_nope.unsqueeze(1) + q = attn.q_b_proj(q).view(-1, attn.num_local_heads, attn.qk_head_dim) + if q_lora is not None: + topk_indices = attn.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=attn.layer_num, + ) + else: + q = attn.q_proj(hidden_states).view(-1, attn.num_local_heads, attn.qk_head_dim) + latent_cache = attn.kv_a_proj_with_mqa(hidden_states)[0] + k_nope = latent_cache[..., : attn.kv_lora_rank] + k_nope = attn.kv_a_layernorm(k_nope).unsqueeze(1) + + q_nope, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) + k_pe = latent_cache[..., attn.kv_lora_rank :].unsqueeze(1) + + q_nope_out = mla_absorbed_bmm( + attn, q_nope, attn.w_kc, attn.w_scale, attn.w_scale_k, attn.kv_lora_rank + ) + + if attn.rotary_emb is not None and not attn.use_fused_qk_rope_concat_and_cache_mla: + q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) + + if nsa_use_prefill_cp(forward_batch): + k_nope, k_pe = attn.rebuild_cp_kv_cache( + latent_cache, forward_batch, k_nope, k_pe + ) + + return SglPrepareResult( + q_pe=q_pe, + k_pe=k_pe, + q_nope_out=q_nope_out, + k_nope=k_nope, + forward_batch=forward_batch, + zero_allocator=zero_allocator, + positions=positions, + topk_indices=topk_indices, + llama_4_scaling=llama_4_scaling, + ) + + +def forward_sgl_core( + attn: DeepseekV2MLAAttention, + prepared: SglPrepareResult, +) -> torch.Tensor: + """Core MLA attention computation for sglang (adapted from sglang forward_absorb_core).""" + save_kv_cache = True + + if 
attn.use_fused_qk_rope_concat_and_cache_mla: + cos = attn.rotary_emb.cos_cache + sin = attn.rotary_emb.sin_cache + kv_cache = prepared.forward_batch.token_to_kv_pool.get_key_buffer(attn.layer_num) + k_scale = attn.mla_attn.attn.k_scale + + q, _, k_pe_roped, _ = fused_qk_rope_cat_and_cache_mla( + prepared.q_nope_out, + prepared.q_pe, + prepared.k_nope, + prepared.k_pe, + kv_cache, + prepared.forward_batch.out_cache_loc, + prepared.positions, + cos, + sin, + k_scale, + attn.rotary_emb.is_neox_style, + q_out_dtype=prepared.q_nope_out.dtype, + ) + k = torch.cat([prepared.k_nope, k_pe_roped], dim=-1) + save_kv_cache = False + else: + q = torch.cat([prepared.q_nope_out, prepared.q_pe], dim=-1) + k = torch.cat([prepared.k_nope, prepared.k_pe], dim=-1) + + if prepared.llama_4_scaling is not None: + q = q * prepared.llama_4_scaling + + extra_kwargs = {} + if prepared.topk_indices is not None: + extra_kwargs["topk_indices"] = prepared.topk_indices + + attn_output = attn.mla_attn( + q, + k, + prepared.k_nope, + forward_batch=prepared.forward_batch, + save_kv_cache=save_kv_cache, + **extra_kwargs, + ) + attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) + + # up-proj by w_vc + attn_bmm_output = mla_absorbed_bmm( + attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim + ).flatten(1, 2) + + return attn.o_proj(attn_bmm_output) + + +def prepare_qkv_latent( + attn: DeepseekV2MLAAttention, + hidden_states: torch.Tensor, + forward_batch, +) -> torch.Tensor: + """Prepare QKV latent tensor for the sglang communicator.""" + assert attn.q_lora_rank is not None + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + qkv_lora = attn.fused_qkv_a_proj(hidden_states, hidden_states_scale) + + # Fallback: when communicator does not enable input_scattered gather, + # force qkv latent token dimension to align with positions. 
+ expected_tokens = 0 + if hasattr(forward_batch, "positions") and forward_batch.positions is not None: + expected_tokens = int(forward_batch.positions.shape[0]) + if expected_tokens <= 0: + expected_tokens = int(getattr(forward_batch, "seq_lens_sum", 0) or 0) + + if ( + expected_tokens > 0 + and qkv_lora.shape[0] != expected_tokens + and get_tensor_model_parallel_world_size() > 1 + ): + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] > expected_tokens: + qkv_lora = qkv_lora[:expected_tokens] + elif qkv_lora.shape[0] < expected_tokens: + raise RuntimeError( + f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " + f"expected {expected_tokens}" + ) + return qkv_lora + + +# --------------------------------------------------------------------------- +# Top-level forward entry point +# --------------------------------------------------------------------------- +def forward_sgl_plugin_mode( + attn: DeepseekV2MLAAttention, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs, +) -> torch.Tensor: + """Full MLA forward in sglang plugin mode.""" + forward_batch = model_kwargs.get("forward_batch", None) + if forward_batch is None: + raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") + + attn_tp_context = get_attn_tp_context() + with attn_tp_context.maybe_input_scattered(forward_batch): + if attn.q_lora_rank is not None: + attn_tp_context.set_attn_inputs( + AttentionInputs( + hidden_states, + forward_batch, + lambda hs, fb: prepare_qkv_latent(attn, hs, fb), + ) + ) + prepared = forward_sgl_prepare(attn, positions, hidden_states, **model_kwargs) + return forward_sgl_core(attn, prepared) + + +# --------------------------------------------------------------------------- +# Weight post-processing: decomposed into sub-functions +# --------------------------------------------------------------------------- +def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: + """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" + if hasattr(attn.kv_b_proj, "qweight"): + awq_dequant = awq_dequantize_func() + if awq_dequant is None: + raise ValueError("AWQ dequantize function is not supported for current device") + w = awq_dequant( + attn.kv_b_proj.qweight, + attn.kv_b_proj.scales, + attn.kv_b_proj.qzeros, + ).T + else: + w = attn.kv_b_proj.weight + + # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes. + # View-cast back to fn so the normalize path works correctly. + if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: + w = w.view(torch.float8_e4m3fn) + + return w + + +def _get_weight_block_size(attn: DeepseekV2MLAAttention) -> Optional[list[int]]: + """Derive weight_block_size from ATOM's quant_type system.""" + from aiter import QuantType as _AiterQuantType + + qt = getattr(attn.kv_b_proj, "quant_type", None) + if qt == _AiterQuantType.per_1x128: + return [128, 128] + elif qt == _AiterQuantType.per_1x32: + return [1, 32] + return None + + +def _process_fp8_weight( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + weight_block_size: Optional[list[int]], +) -> tuple[torch.Tensor, bool, Optional[torch.Tensor]]: + """Process FP8 weights for kv_b_proj. + + Returns (w, use_deep_gemm_bmm, block_scale). 
+ """ + from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz + from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + inverse_transform_scale_ue8m0, + ) + from sglang.srt.layers.deep_gemm_wrapper import ENABLE_JIT_DEEPGEMM, DEEPGEMM_BLACKWELL + from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 + + use_deep_gemm_bmm = False + block_scale = None + + if weight_block_size is not None: + assert hasattr(attn.kv_b_proj, "weight_scale_inv") or hasattr( + attn.kv_b_proj, "weight_scale" + ) + weight_scale = ( + attn.kv_b_proj.weight_scale + if hasattr(attn.kv_b_proj, "weight_scale") + else attn.kv_b_proj.weight_scale_inv + ) + + if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fn: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, weight_scale=weight_scale, input_scale=None + ) + else: + weight = w + + if ( + should_deepgemm_weight_requant_ue8m0(weight_block_size=weight_block_size) + and getattr(weight_scale, "format_ue8m0", False) + ): + weight_scale = inverse_transform_scale_ue8m0( + weight_scale, mn=weight.shape[-2] + ) + + if _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128: + if ( + ENABLE_JIT_DEEPGEMM + and not DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, weight_scale, weight_block_size, torch.bfloat16 + ) + else: + w, scale = block_quant_to_tensor_quant(weight, weight_scale, weight_block_size) + attn.w_scale = scale + else: + if w.dtype == torch.float8_e4m3fn and _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, weight_scale=attn.kv_b_proj.weight_scale, input_scale=None + ) + else: + weight = w + weight_scale = attn.kv_b_proj.weight_scale + + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + attn.w_scale = scale + + return w, use_deep_gemm_bmm, block_scale + + +def _process_int8_weight( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + weight_block_size: Optional[list[int]], +) -> torch.Tensor: + """Process INT8 weights for kv_b_proj.""" + from sglang.srt.layers.quantization.int8_utils import block_dequant as int8_block_dequant + + if weight_block_size is not None: + assert hasattr(attn.kv_b_proj, "weight_scale_inv") + return int8_block_dequant( + w, attn.kv_b_proj.weight_scale_inv, weight_block_size + ).to(torch.bfloat16) + else: + return w.to(torch.bfloat16) * attn.kv_b_proj.weight_scale.to(torch.bfloat16) + + +def _split_and_assign_kc_vc( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + use_deep_gemm_bmm: bool, + block_scale: Optional[torch.Tensor], + weight_block_size: Optional[list[int]], +) -> None: + """Split weight into kc/vc and assign to attn.""" + from atom.model_ops.utils import quark_post_load_weights + + w_kc, w_vc = w.unflatten( + 0, (-1, attn.qk_nope_head_dim + attn.v_head_dim) + ).split([attn.qk_nope_head_dim, attn.v_head_dim], dim=1) + + # quark fp4 special path + quant_method = getattr(attn.kv_b_proj, "quant_method", None) + quant_config = getattr(quant_method, "quant_config", None) + if ( + _use_aiter_gfx95 + and quant_config is not None + and quant_config.get_name() == "quark" + ): + w_kc, attn.w_scale_k, w_vc, attn.w_scale_v = quark_post_load_weights( + attn, w, "mxfp4" + ) + + if not use_deep_gemm_bmm: + attn.w_kc = bind_or_assign( + attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) + ) + w_vc = 
w_vc.contiguous().transpose(1, 2) + if _is_npu: + w_vc = w_vc.contiguous() + attn.w_vc = bind_or_assign(attn.w_vc, w_vc) + + if hasattr(attn.kv_b_proj, "weight_scale") and attn.w_scale is None: + attn.w_scale = bind_or_assign(attn.w_scale, attn.kv_b_proj.weight_scale) + if _is_hip: + attn.w_scale *= 2.0 + + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + attn.w_kc = attn.w_kc.to(torch.bfloat16) * attn.w_scale + attn.w_vc = attn.w_vc.to(torch.bfloat16) * attn.w_scale + else: + num_tiles_k = attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + + attn.w_scale_k = bind_or_assign(attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()) + attn.w_scale_v = bind_or_assign(attn.w_scale_v, ws_vc.contiguous()) + attn.w_kc = bind_or_assign(attn.w_kc, w_kc.transpose(1, 2).contiguous()) + attn.w_vc = bind_or_assign(attn.w_vc, w_vc.contiguous()) + attn.use_deep_gemm_bmm = True + + +def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: + """Process kv_b_proj weights after loading for sglang MLA mode. + + Orchestrates reading, quantization handling, and splitting of + kv_b_proj into absorbed w_kc / w_vc weights. + """ + w = _read_kv_b_proj_weight(attn) + weight_block_size = _get_weight_block_size(attn) + + use_deep_gemm_bmm = False + block_scale = None + + # fp8 path + if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + w, use_deep_gemm_bmm, block_scale = _process_fp8_weight( + attn, w, weight_block_size + ) + + # int8 path + if w.dtype == torch.int8: + w = _process_int8_weight(attn, w, weight_block_size) + + # split and assign kc/vc + _split_and_assign_kc_vc(attn, w, use_deep_gemm_bmm, block_scale, weight_block_size) From 6c765ad306ed895679b3b28d64b2546100f60717 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Thu, 26 Mar 2026 06:31:22 +0000 Subject: [PATCH 13/15] add comments Signed-off-by: zhuyuhua-v --- atom/model_ops/radix_attention.py | 10 +++++-- atom/models/deepseek_v2.py | 5 +++- atom/models/qwen3_moe.py | 29 +++++++++---------- .../attention_backend/sgl_attn_backend.py | 18 ++++++------ atom/plugin/register.py | 10 +++++-- atom/plugin/sglang/mla.py | 1 - 6 files changed, 43 insertions(+), 30 deletions(-) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index d0a9f0453..839e33629 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -13,8 +13,10 @@ class RadixAttention(BaseAttention): - """ - Attention radix implementation + """Attention wrapper for sglang plugin mode. + + Delegates to sglang's RadixAttention internally, adapting ATOM's + attention interface to sglang's forward_batch-based API. """ def __init__( @@ -63,6 +65,8 @@ def __init__( v_head_dim=_v_head_dim, prefix=maybe_prefix(prefix, "attn"), ) + # sglang's RadixAttention expects k_scale/v_scale on device; + # ensure they exist with identity scaling for non-quantised KV cache. if self.attn.k_scale is None: self.attn.k_scale = torch.nn.Parameter( torch.tensor([1.0], dtype=torch.float32, device="cuda"), @@ -96,6 +100,8 @@ def forward_impl_plugin_mode( if is_sglang(): # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) + # When fused rope+qknorm is active, KV cache is saved inside the + # fused kernel, so we skip the separate save step in sglang's attn. 
save_kv_cache = kwargs.get("save_kv_cache", not self.use_aiter_rope_fused_qknorm) assert forward_batch is not None, "forward_batch is required for sglang" # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 6d2bb86ca..810cf16a8 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1559,13 +1559,15 @@ def forward( hidden_states: torch.Tensor, **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: + # Sglang plugin mode uses its own forward path with absorbed MLA weights + # and sglang-specific attention backend. See atom/plugin/sglang/mla.py. if is_sglang(): from atom.plugin.sglang.mla import forward_sgl_plugin_mode return forward_sgl_plugin_mode(self, positions, hidden_states, **model_kwargs) return self.forward_common(positions, hidden_states, **model_kwargs) def process_weights_after_loading(self) -> None: - # only for sglang plugin mode + """Post-load hook: split kv_b_proj into absorbed w_kc / w_vc for sglang MLA.""" if not is_sglang(): return from atom.plugin.sglang.mla import process_mla_kv_b_proj_after_loading @@ -1923,6 +1925,7 @@ def __init__( self.model.make_empty_intermediate_tensors ) + # Initialise sglang's TP attention context for MLA gather/scatter. if is_sglang(): from sglang.srt.configs.model_config import is_deepseek_nsa from sglang.srt.layers.communicator import get_attn_tp_context diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 421a68c78..714b21129 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -241,6 +241,7 @@ def forward_sgl_plugin_mode( qkv: torch.Tensor, **model_kwargs: dict[str, Any] | None, ): + """Sglang forward path: fused rope+qknorm+cache or split+norm+rope.""" if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: forward_batch = model_kwargs.get("forward_batch", None) assert forward_batch is not None, "forward_batch is required for sglang" @@ -299,7 +300,6 @@ def forward( **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) - q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: q, k, v = torch.split( qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 @@ -307,19 +307,21 @@ def forward( attn_output = self.attn( query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv ) + elif is_sglang(): + attn_output = self.forward_sgl_plugin_mode( + positions, qkv, **model_kwargs + ) else: - if is_sglang(): - attn_output = self.forward_sgl_plugin_mode( - positions, qkv, **model_kwargs - ) - else: - # Add qk-norm - q = self.q_norm(q) - k = self.k_norm(k) + q, k, v = torch.split( + qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) - attn_output = self.attn( - query=q, key=k, value=v, positions=positions, **model_kwargs - ) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, **model_kwargs + ) output = self.o_proj(attn_output) return output @@ -468,9 +470,6 @@ def forward( inputs_embeds: torch.Tensor | None = None, **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: - # import logging - # logger = logging.getLogger("atom.models.qwen3_moe") - # logger.info(f"atom call Qwen3MoeModel") if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds diff --git 
a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py index 213c9db38..8bd49661c 100644 --- a/atom/plugin/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/attention_backend/sgl_attn_backend.py @@ -162,6 +162,7 @@ def reshape_and_cache_shuffle_triton( @dataclass class ForwardMetadata: + """Per-batch metadata consumed by ATOM's attention kernels (pa_fwd_asm, mla_decode_fwd, etc.).""" # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode kv_indptr: Optional[torch.Tensor] kv_indices: Optional[torch.Tensor] @@ -191,6 +192,14 @@ class ForwardMetadata: class ATOMAttnBackendForSgl(AiterAttnBackend): + """ATOM's custom attention backend for sglang plugin mode. + + Extends sglang's AiterAttnBackend with ATOM-specific optimisations: + page-table management, pa_persistent_fwd decode path, and MLA + prefill kernels (fp8, decompress, absorbed). Registered to sglang + via atom.plugin.register._register_custom_attention_to_sglang(). + """ + def __init__( self, model_runner: ModelRunner, @@ -644,15 +653,6 @@ def _build_pa_metadata_for_prefill(self, batch_size: int): self.pa_kv_indices, page_table.stride(0), ) - # kv_indices = self.pa_kv_indices - - # Compute kv_last_page_lens for each sequence - # kv_last_page_lens = ((context_lens - 1) % block_size + 1).int() - - # Store in ForwardMetadata for reuse in forward_extend - # self.forward_metadata.prefill_pages_kv_indptr = pages_kv_indptr - # self.forward_metadata.prefill_kv_indices = kv_indices - # self.forward_metadata.prefill_kv_last_page_lens = kv_last_page_lens def init_cuda_graph_state( self, diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 55ee75bfb..5b824d3bb 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -18,7 +18,11 @@ def _register_custom_attention_to_sglang() -> None: + """Override sglang's built-in "aiter" attention backend with ATOM's implementation. + sglang only accepts pre-registered backend names, so we reuse the "aiter" + name to inject ATOMAttnBackendForSgl without modifying sglang source. + """ from sglang.srt.layers.attention.attention_registry import ( register_attention_backend, ) @@ -43,8 +47,10 @@ def register_ops_to_sglang(atom_config: Config) -> None: def set_attn_cls() -> None: - """ - Set the attention class for constructing the model based on the framework + """Swap ``atom.model_ops.Attention`` to the framework-appropriate class. + + ATOM models reference ``ops.Attention`` generically; this function binds + it to PagedAttention (vLLM) or RadixAttention (sglang) at plugin init time. 
""" import atom.model_ops as ops diff --git a/atom/plugin/sglang/mla.py b/atom/plugin/sglang/mla.py index e7d52a9c7..334381157 100644 --- a/atom/plugin/sglang/mla.py +++ b/atom/plugin/sglang/mla.py @@ -25,7 +25,6 @@ _is_cuda, _is_fp8_fnuz, _is_npu, - _use_aiter_gfx95, awq_dequantize_func, ) from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( From d994d677f22101b91fbab8d78edffc4c38496a48 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Thu, 26 Mar 2026 06:39:32 +0000 Subject: [PATCH 14/15] change sglang args and qwen page size Signed-off-by: zhuyuhua-v --- atom/models/qwen3_moe.py | 24 +++++++++++++----------- atom/plugin/config.py | 15 ++++----------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index 714b21129..7689f4474 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -46,6 +46,18 @@ ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM +def _get_page_size(forward_batch, default: int = 1024) -> int: + """Resolve page_size from forward_batch's attn_backend or token pool.""" + for obj in ( + getattr(forward_batch, "attn_backend", None), + getattr(getattr(forward_batch, "token_to_kv_pool", None), "allocator", None), + getattr(forward_batch, "token_to_kv_pool", None), + ): + if obj is not None and hasattr(obj, "page_size"): + return obj.page_size + return default + + class Qwen3MoeMLP(nn.Module): def __init__( self, @@ -248,17 +260,7 @@ def forward_sgl_plugin_mode( k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( self.layer_num ) - block_size = 1024 # Default fallback - if hasattr(forward_batch, "attn_backend") and hasattr( - forward_batch.attn_backend, "page_size" - ): - block_size = forward_batch.attn_backend.page_size - elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( - forward_batch.token_to_kv_pool.allocator, "page_size" - ): - block_size = forward_batch.token_to_kv_pool.allocator.page_size - elif hasattr(forward_batch.token_to_kv_pool, "page_size"): - block_size = forward_batch.token_to_kv_pool.page_size + block_size = _get_page_size(forward_batch) x = 16 // k_buffer.element_size() aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( kv_cache=(k_buffer, v_buffer), diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 88548e8c2..68eac74d8 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -1,5 +1,3 @@ -import sys - from typing import Any, Optional from dataclasses import dataclass @@ -110,8 +108,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: def _generate_atom_config_from_sglang_config(config: Any): from sglang.srt.server_args import ( - ServerArgs, - prepare_server_args, + get_global_server_args, PortArgs, ) from sglang.srt.configs.model_config import ModelConfig as SglangModelConfig @@ -119,13 +116,9 @@ def _generate_atom_config_from_sglang_config(config: Any): from sglang.srt.configs.load_config import LoadConfig from atom.config import Config, ParallelConfig, CompilationConfig - # Format1: sglang serve --model-path ... - # Format2: python3 -m sglang.launch_server --model-path ... 
- args_list = sys.argv[2:] if sys.argv[1] == "serve" else sys.argv[1:] - # sglang has no global config variable like vllm, - # so here construct the server args from sys.argv passed by users - # this is the only way to get full arguments - server_args: ServerArgs = prepare_server_args(args_list) + # sglang's ModelRunner already parsed and stored ServerArgs globally + # before OOT model loading, so we can retrieve it directly. + server_args = get_global_server_args() sgl_model_config = SglangModelConfig.from_server_args(server_args) sgl_model_opt_config = ModelOptConfig( From 5bf05b0f30c57841dd552b59c052f6dcc844f4f8 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Thu, 26 Mar 2026 08:36:02 +0000 Subject: [PATCH 15/15] remove redundant code Signed-off-by: zhuyuhua-v --- atom/model_ops/__init__.py | 1 - atom/model_ops/linear.py | 1 - atom/models/deepseek_v2.py | 9 ++++----- atom/plugin/attention_backend/__init__.py | 0 atom/plugin/register.py | 2 +- .../sglang/{mla.py => sgl_attention_mla.py} | 15 +-------------- .../sgl_attn_backend.py | 0 7 files changed, 6 insertions(+), 22 deletions(-) delete mode 100644 atom/plugin/attention_backend/__init__.py rename atom/plugin/sglang/{mla.py => sgl_attention_mla.py} (95%) rename atom/plugin/{attention_backend => sglang}/sgl_attn_backend.py (100%) diff --git a/atom/model_ops/__init__.py b/atom/model_ops/__init__.py index 4e0c0c258..4b6c0b545 100644 --- a/atom/model_ops/__init__.py +++ b/atom/model_ops/__init__.py @@ -5,7 +5,6 @@ # it can be assigned to different attention ops. # By default, PagedAttention is used. # For sglang, RadixAttention will be assigned to Attention -# see register.py for details. Attention = PagedAttention __all__ = [ diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index a25f4070d..f0d73f838 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -394,7 +394,6 @@ def process_weights_after_loading(self): if self.quant_type == QuantType.per_1x32: self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data) - @mark_trace def forward( self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16 diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 383029b11..a754dfd4d 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -1456,7 +1456,7 @@ def __init__( # sglang plugin mode attributes (lazily initialised) if is_sglang(): - from atom.plugin.sglang.mla import init_sgl_attrs + from atom.plugin.sglang.sgl_attention_mla import init_sgl_attrs init_sgl_attrs(self, config) @@ -1571,9 +1571,9 @@ def forward( **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: # Sglang plugin mode uses its own forward path with absorbed MLA weights - # and sglang-specific attention backend. See atom/plugin/sglang/mla.py. + # and sglang-specific attention backend. See atom/plugin/sglang/sgl_attention_mla.py. 
if is_sglang(): - from atom.plugin.sglang.mla import forward_sgl_plugin_mode + from atom.plugin.sglang.sgl_attention_mla import forward_sgl_plugin_mode return forward_sgl_plugin_mode(self, positions, hidden_states, **model_kwargs) return self.forward_common(positions, hidden_states, **model_kwargs) @@ -1581,7 +1581,7 @@ def process_weights_after_loading(self) -> None: """Post-load hook: split kv_b_proj into absorbed w_kc / w_vc for sglang MLA.""" if not is_sglang(): return - from atom.plugin.sglang.mla import process_mla_kv_b_proj_after_loading + from atom.plugin.sglang.sgl_attention_mla import process_mla_kv_b_proj_after_loading process_mla_kv_b_proj_after_loading(self) @@ -1786,7 +1786,6 @@ def __init__( prefix: str = "", layer_type: type[nn.Module] = DeepseekV2DecoderLayer, ): - # logger.info(f"atom call DeepseekV2Model") super().__init__() config = atom_config.hf_config diff --git a/atom/plugin/attention_backend/__init__.py b/atom/plugin/attention_backend/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 5b824d3bb..8ca19e7b1 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -34,7 +34,7 @@ def _register_custom_attention_to_sglang() -> None: @register_attention_backend("aiter") def create_atom_backend(runner): - from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl + from atom.plugin.sglang.sgl_attn_backend import ATOMAttnBackendForSgl return ATOMAttnBackendForSgl(runner) diff --git a/atom/plugin/sglang/mla.py b/atom/plugin/sglang/sgl_attention_mla.py similarity index 95% rename from atom/plugin/sglang/mla.py rename to atom/plugin/sglang/sgl_attention_mla.py index 334381157..c54bce421 100644 --- a/atom/plugin/sglang/mla.py +++ b/atom/plugin/sglang/sgl_attention_mla.py @@ -43,9 +43,8 @@ if TYPE_CHECKING: from atom.models.deepseek_v2 import DeepseekV2MLAAttention -# --------------------------------------------------------------------------- + # bmm_fp8 custom-op wrapper (adapted from sglang forward_mla.py) -# --------------------------------------------------------------------------- if _is_cuda: from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 from sglang.srt.utils.custom_op import register_custom_op @@ -76,9 +75,7 @@ def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") -# --------------------------------------------------------------------------- # NamedTuple for prepare → core data flow -# --------------------------------------------------------------------------- class SglPrepareResult(NamedTuple): q_pe: torch.Tensor k_pe: torch.Tensor @@ -91,9 +88,7 @@ class SglPrepareResult(NamedTuple): llama_4_scaling: Optional[Any] -# --------------------------------------------------------------------------- # Init helpers -# --------------------------------------------------------------------------- def init_sgl_attrs(attn: DeepseekV2MLAAttention, config) -> None: """Initialise sglang-only attributes on DeepseekV2MLAAttention.""" from sglang.srt.configs.model_config import is_deepseek_nsa @@ -108,9 +103,7 @@ def init_sgl_attrs(attn: DeepseekV2MLAAttention, config) -> None: attn.w_scale_v = None -# --------------------------------------------------------------------------- # Absorbed batched-matmul (shared by prepare and core) -# --------------------------------------------------------------------------- def mla_absorbed_bmm( attn: DeepseekV2MLAAttention, inp: torch.Tensor, @@ -194,9 +187,7 @@ def mla_absorbed_bmm( 
return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) -# --------------------------------------------------------------------------- # Forward: prepare → core -# --------------------------------------------------------------------------- def forward_sgl_prepare( attn: DeepseekV2MLAAttention, positions: torch.Tensor, @@ -416,9 +407,7 @@ def prepare_qkv_latent( return qkv_lora -# --------------------------------------------------------------------------- # Top-level forward entry point -# --------------------------------------------------------------------------- def forward_sgl_plugin_mode( attn: DeepseekV2MLAAttention, positions: torch.Tensor, @@ -444,9 +433,7 @@ def forward_sgl_plugin_mode( return forward_sgl_core(attn, prepared) -# --------------------------------------------------------------------------- # Weight post-processing: decomposed into sub-functions -# --------------------------------------------------------------------------- def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" if hasattr(attn.kv_b_proj, "qweight"): diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/sgl_attn_backend.py similarity index 100% rename from atom/plugin/attention_backend/sgl_attn_backend.py rename to atom/plugin/sglang/sgl_attn_backend.py
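
A note on the page-size lookup introduced above in atom/models/qwen3_moe.py: _get_page_size walks three candidates in order (the forward batch's attn_backend, the KV pool's allocator, and the KV pool itself) and returns the first page_size it finds, falling back to 1024, the previous hard-coded default. Below is a minimal, self-contained sketch of that resolution order; the SimpleNamespace stand-ins are illustrative only and are not sglang's real ForwardBatch, allocator, or pool classes.

    from types import SimpleNamespace
    from typing import Any

    def _get_page_size(forward_batch: Any, default: int = 1024) -> int:
        """Resolve page_size from the attn backend, the pool allocator, or the pool."""
        candidates = (
            getattr(forward_batch, "attn_backend", None),
            getattr(getattr(forward_batch, "token_to_kv_pool", None), "allocator", None),
            getattr(forward_batch, "token_to_kv_pool", None),
        )
        for obj in candidates:
            if obj is not None and hasattr(obj, "page_size"):
                return obj.page_size
        return default

    # Stand-in objects for illustration only; not sglang's real classes.
    batch_with_backend = SimpleNamespace(
        attn_backend=SimpleNamespace(page_size=64),
        token_to_kv_pool=SimpleNamespace(allocator=SimpleNamespace(page_size=128)),
    )
    batch_with_allocator_only = SimpleNamespace(
        attn_backend=None,
        token_to_kv_pool=SimpleNamespace(allocator=SimpleNamespace(page_size=128)),
    )
    batch_with_nothing = SimpleNamespace(attn_backend=None, token_to_kv_pool=None)

    assert _get_page_size(batch_with_backend) == 64          # backend wins when present
    assert _get_page_size(batch_with_allocator_only) == 128  # falls back to the allocator
    assert _get_page_size(batch_with_nothing) == 1024        # default when nothing exposes page_size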