diff --git a/atom/config.py b/atom/config.py index 0bb018b69..fbe67341d 100644 --- a/atom/config.py +++ b/atom/config.py @@ -887,24 +887,31 @@ def __post_init__(self): # assert os.path.isdir(self.model) assert 1 <= self.tensor_parallel_size <= 8 - self.hf_config = get_hf_config( - self.model, trust_remote_code=self.trust_remote_code - ) + if is_plugin_mode(): + # plugin mode + assert ( + self.plugin_config is not None + ), "plugin_config is required in plugin mode" + self.hf_config = self.plugin_config.model_config.hf_config + else: + self.hf_config = get_hf_config( + self.model, trust_remote_code=self.trust_remote_code + ) + + self.generation_config = get_generation_config(self.model) + if self.generation_config is not None: + if ( + eos_ids := getattr(self.generation_config, "eos_token_id", None) + ) is not None: + self.stop_token_ids = ( + [eos_ids] if isinstance(eos_ids, int) else eos_ids + ) if not hasattr(self.hf_config, "rope_parameters"): # Compatible with both transformers < 5 - rope_params = getattr(self.hf_config, "rope_scaling", {}) - if rope_params is None: - rope_params = {} - rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None) - rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default") + rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} + rope_params["rope_theta"] = self.hf_config.rope_theta + rope_params["rope_type"] = rope_params.get("rope_type", "default") self.hf_config.rope_parameters = rope_params - - self.generation_config = get_generation_config(self.model) - if self.generation_config is not None: - if ( - eos_ids := getattr(self.generation_config, "eos_token_id", None) - ) is not None: - self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids self.quant_config = QuantizationConfig( self.hf_config, self.plugin_config.vllm_config if self.plugin_config is not None else None, diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 5642d0b8e..58c29b6de 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1406,8 +1406,7 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, ) else: - # Direct kernel call for non-EP/DP cases - return rocm_asm_moe_impl( + return torch.ops.aiter.rocm_aiter_fused_moe( x, layer.w13_weight, layer.w2_weight, @@ -1698,13 +1697,26 @@ def get_fused_moe_quant_config( a2_scale=layer.w2_input_scale, per_act_token_quant=True, ) + elif self.block_quant: + if self.quant_type == QuantType.per_1x128: + block_shape = [128, 128] + elif self.quant_type == QuantType.per_1x32: + block_shape = [1, 32] + else: + block_shape = None + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=block_shape, + ) else: return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - block_shape=None, ) @mark_trace(prefix="fp8_moe", torch_compile=False) diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py index 25388b384..839e33629 100644 --- a/atom/model_ops/radix_attention.py +++ b/atom/model_ops/radix_attention.py @@ -9,11 +9,14 @@ from .base_attention import BaseAttention from atom.plugin.prepare import is_plugin_mode, is_sglang from atom.models.utils import maybe_prefix +from atom.utils import envs class RadixAttention(BaseAttention): - """ - Attention radix implementation + """Attention wrapper for sglang 
plugin mode. + + Delegates to sglang's RadixAttention internally, adapting ATOM's + attention interface to sglang's forward_batch-based API. """ def __init__( @@ -47,23 +50,39 @@ def __init__( prefix=prefix, **kwargs, ) - self.rotary_emb = rotary_emb if is_sglang(): from sglang.srt.layers.radix_attention import RadixAttention + _v_head_dim = mla_modules.kv_lora_rank if (use_mla and mla_modules is not None) else head_dim + self.attn = RadixAttention( num_heads=num_heads, head_dim=head_dim, scaling=scale, num_kv_heads=num_kv_heads, layer_id=layer_num, + v_head_dim=_v_head_dim, prefix=maybe_prefix(prefix, "attn"), ) + # sglang's RadixAttention expects k_scale/v_scale on device; + # ensure they exist with identity scaling for non-quantised KV cache. + if self.attn.k_scale is None: + self.attn.k_scale = torch.nn.Parameter( + torch.tensor([1.0], dtype=torch.float32, device="cuda"), + requires_grad=False, + ) + if self.attn.v_scale is None: + self.attn.v_scale = torch.nn.Parameter( + torch.tensor([1.0], dtype=torch.float32, device="cuda"), + requires_grad=False, + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" ) + # if True, save cache will be done in rope + self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM def forward_impl_plugin_mode( self, @@ -81,11 +100,18 @@ def forward_impl_plugin_mode( if is_sglang(): # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) + # When fused rope+qknorm is active, KV cache is saved inside the + # fused kernel, so we skip the separate save step in sglang's attn. + save_kv_cache = kwargs.get("save_kv_cache", not self.use_aiter_rope_fused_qknorm) assert forward_batch is not None, "forward_batch is required for sglang" - if self.rotary_emb is not None: - assert positions is not None, "positions is required for ROPE" - query, key = self.rotary_emb(positions, query, key) - return self.attn(q=query, k=key, v=value, forward_batch=forward_batch) + # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM + return self.attn( + query, + key, + value, + forward_batch=forward_batch, + save_kv_cache=save_kv_cache, + ) else: raise NotImplementedError( "RadixAttention is only supported for plugin mode for sglang for now" diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index c7e59d7b4..a754dfd4d 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -23,8 +23,9 @@ # limitations under the License. 
"""Inference-only DeepseekV2/DeepseekV3 model.""" +import json import logging -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Iterable, Any import torch from aiter import ( @@ -38,7 +39,7 @@ top_k_per_row_prefill, ) from aiter.dist.communication_op import tensor_model_parallel_all_reduce -from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size +from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size, get_tp_group from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits from aiter.ops.triton.fused_fp8_quant import ( @@ -50,6 +51,7 @@ fused_reduce_rms_mxfp4_quant, fused_rms_mxfp4_quant, ) +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits from aiter.rotary_embedding import get_rope from atom.config import Config, QuantizationConfig, get_current_atom_config @@ -66,9 +68,9 @@ RowParallelLinear, use_triton_gemm, ) +from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE, _has_module, quark_post_load_weights from atom.model_ops.moe import FusedMoE from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled -from atom.model_ops.utils import MXFP4_QUANT_BLOCK_SIZE from atom.models.utils import ( IntermediateTensors, PPMissingLayer, @@ -82,9 +84,7 @@ from atom.utils.forward_context import get_forward_context from torch import nn from transformers import PretrainedConfig - -# from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 - +from atom.plugin.prepare import is_sglang logger = logging.getLogger("atom") if use_triton_gemm(): @@ -106,6 +106,7 @@ gemm_a8w8_blockscale_preshuffle = None gemm_a16w8_blockscale_preshuffle = None + ENABLE_DS_QKNORM_QUANT_FUSION = envs.ATOM_ENABLE_DS_QKNORM_QUANT_FUSION ENABLE_DS_QKNORM_FUSION = envs.ATOM_ENABLE_DS_QKNORM_FUSION ENABLE_ALLREDUCE_RMSNORM_FUSION = envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION @@ -1453,11 +1454,18 @@ def __init__( self.quant_dtype = layer_quant_dtype self.fuse_qknorm_quant = True - def forward( + # sglang plugin mode attributes (lazily initialised) + if is_sglang(): + from atom.plugin.sglang.sgl_attention_mla import init_sgl_attrs + + init_sgl_attrs(self, config) + + def forward_common( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + **model_kwargs: dict[str, Any] | None + ): hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states @@ -1556,6 +1564,26 @@ def forward( hidden_states_or_q_c_scale, ) + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs: dict[str, Any] | None + ) -> torch.Tensor: + # Sglang plugin mode uses its own forward path with absorbed MLA weights + # and sglang-specific attention backend. See atom/plugin/sglang/sgl_attention_mla.py. 
+ if is_sglang(): + from atom.plugin.sglang.sgl_attention_mla import forward_sgl_plugin_mode + return forward_sgl_plugin_mode(self, positions, hidden_states, **model_kwargs) + return self.forward_common(positions, hidden_states, **model_kwargs) + + def process_weights_after_loading(self) -> None: + """Post-load hook: split kv_b_proj into absorbed w_kc / w_vc for sglang MLA.""" + if not is_sglang(): + return + from atom.plugin.sglang.sgl_attention_mla import process_mla_kv_b_proj_after_loading + process_mla_kv_b_proj_after_loading(self) + class DeepseekV2DecoderLayer(nn.Module): @@ -1661,14 +1689,13 @@ def __init__( self.fuse_rmsnorm_quant = ( ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION and self.quant_dtype is not None ) - def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], + **model_kwargs: dict[str, Any] | None ) -> torch.Tensor: - # Self Attention if self.fuse_input_norm_quant: assert self.quant_dtype is not None weight = self.input_layernorm.weight @@ -1723,6 +1750,7 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + **model_kwargs, ) if hidden_states.dtype == torch.float16: @@ -1828,6 +1856,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + **model_kwargs: dict[str, Any] | None ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -1841,7 +1870,7 @@ def forward( residual = intermediate_tensors["residual"] for layer in self.layers[self.start_layer : self.end_layer]: - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer(positions, hidden_states, residual, **model_kwargs) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -1880,6 +1909,7 @@ def __init__( quant_config = atom_config.quant_config self.config = config self.quant_config = quant_config + self.atom_config = atom_config if hasattr(config, "q_lora_rank") and config.q_lora_rank is not None: self.packed_modules_mapping = { @@ -1913,6 +1943,12 @@ def __init__( self.model.make_empty_intermediate_tensors ) + # Initialise sglang's TP attention context for MLA gather/scatter. + if is_sglang(): + from sglang.srt.configs.model_config import is_deepseek_nsa + from sglang.srt.layers.communicator import get_attn_tp_context + get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -1922,9 +1958,11 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + **model_kwargs: dict[str, Any] | None ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds + input_ids, positions, intermediate_tensors, inputs_embeds, + **model_kwargs, ) return hidden_states @@ -1952,6 +1990,19 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". 
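+        # (The plugin wrapper stores this ATOM model as self.model, which is
+        # why the loaded names carry the "model." prefix.)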
+ # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + + # lazy import to avoid circular import issue since model_loader also imports model.. + from atom.model_loader.loader import load_model_in_plugin_mode + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py index acc292bb1..3e589606c 100644 --- a/atom/models/qwen3_moe.py +++ b/atom/models/qwen3_moe.py @@ -1,11 +1,11 @@ -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Iterable import torch from aiter.dist.communication_op import tensor_model_parallel_all_reduce from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size # from atom.model_ops.rotary_embedding import get_rope -from aiter.rotary_embedding import get_rope +from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg from atom.config import Config, QuantizationConfig from atom.model_ops.activation import SiluAndMul @@ -33,6 +33,8 @@ ) from atom.utils import envs from torch import nn +from atom.model_loader.loader import load_model_in_plugin_mode +from atom.plugin.prepare import is_sglang # import torch.distributed as dist from transformers import PretrainedConfig @@ -41,6 +43,19 @@ ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION ) +ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM + + +def _get_page_size(forward_batch, default: int = 1024) -> int: + """Resolve page_size from forward_batch's attn_backend or token pool.""" + for obj in ( + getattr(forward_batch, "attn_backend", None), + getattr(getattr(forward_batch, "token_to_kv_pool", None), "allocator", None), + getattr(forward_batch, "token_to_kv_pool", None), + ): + if obj is not None and hasattr(obj, "page_size"): + return obj.page_size + return default class Qwen3MoeMLP(nn.Module): @@ -230,6 +245,56 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.layer_num = layer_num + self.k_scale = torch.tensor([1.0], dtype=torch.float32) + self.v_scale = torch.tensor([1.0], dtype=torch.float32) + + def forward_sgl_plugin_mode( + self, + positions: torch.Tensor, + qkv: torch.Tensor, + **model_kwargs: dict[str, Any] | None, + ): + """Sglang forward path: fused rope+qknorm+cache or split+norm+rope.""" + if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: + forward_batch = model_kwargs.get("forward_batch", None) + assert forward_batch is not None, "forward_batch is required for sglang" + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + self.layer_num + ) + block_size = _get_page_size(forward_batch) + x = 16 // k_buffer.element_size() + aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( + kv_cache=(k_buffer, v_buffer), + cache_loc=forward_batch.out_cache_loc, + k_scale=self.k_scale, + v_scale=self.v_scale, + return_kv=True, + use_shuffle_layout=True, + block_size=block_size, + x=x, + ) + q, k, v = self.rotary_emb( + qkv, + self.q_norm.weight, + self.k_norm.weight, + positions, + self.num_heads, + self.num_kv_heads, + self.q_norm.eps, + fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, + ) + else: + q, k, v = torch.split( + qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + # Add qk-norm + q = self.q_norm(q) + k = self.k_norm(k) + + q, k = self.rotary_emb(positions, q, k) + 
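+        # Both branches hand q/k to attention already qk-normed and roped; in the
+        # fused branch the rope kernel has also written the KV cache, so
+        # RadixAttention skips its own save (save_kv_cache defaults to False when
+        # ATOM_ROPE_FUSED_QKNORM is set).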
+ attn_output = self.attn(q, k, v, positions=positions, **model_kwargs) + return attn_output def forward( self, @@ -238,7 +303,6 @@ def forward( **model_kwargs: dict[str, Any] | None, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) - q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1) if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION: q, k, v = torch.split( qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 @@ -246,6 +310,10 @@ def forward( attn_output = self.attn( query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv ) + elif is_sglang(): + attn_output = self.forward_sgl_plugin_mode( + positions, qkv, **model_kwargs + ) else: # Add qk-norm (per-head) q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( @@ -271,7 +339,7 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> No self.hidden_size = config.hidden_size rope_params = config.rope_parameters rope_theta = rope_params["rope_theta"] - rope_scaling = rope_params + rope_scaling = None if rope_params["rope_type"] == "default" else rope_params kv_cache_dtype = atom_config.kv_cache_dtype max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix @@ -353,7 +421,7 @@ def forward( @support_torch_compile -class Qwen3MoeModel(nn.Module): +class Qwen3MoeModel(torch.nn.Module): def __init__( self, atom_config: Config, @@ -527,3 +595,14 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # load weights in plugin mode and discard passed weights generator + # here prefix is "model." because Qwen3MoeForCausalLM is constructed in model + # wrapper class, so the name of loaded weights are prefixed with "model.". + # The vLLM will check the name of the loaded weights to make sure all the + # weights are loaded correctly + loaded_weights_record = load_model_in_plugin_mode( + model=self, config=self.atom_config, prefix="model." + ) + return loaded_weights_record diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 9ffc6392a..68eac74d8 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -1,5 +1,3 @@ -import sys - from typing import Any, Optional from dataclasses import dataclass @@ -110,8 +108,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: def _generate_atom_config_from_sglang_config(config: Any): from sglang.srt.server_args import ( - ServerArgs, - prepare_server_args, + get_global_server_args, PortArgs, ) from sglang.srt.configs.model_config import ModelConfig as SglangModelConfig @@ -119,10 +116,9 @@ def _generate_atom_config_from_sglang_config(config: Any): from sglang.srt.configs.load_config import LoadConfig from atom.config import Config, ParallelConfig, CompilationConfig - # sglang has no global config variable like vllm, - # so here construct the server args from sys.argv passed by users - # this is the only way to get full arguments - server_args: ServerArgs = prepare_server_args(sys.argv[1:]) + # sglang's ModelRunner already parsed and stored ServerArgs globally + # before OOT model loading, so we can retrieve it directly. 
+ server_args = get_global_server_args() sgl_model_config = SglangModelConfig.from_server_args(server_args) sgl_model_opt_config = ModelOptConfig( @@ -214,7 +210,6 @@ def generate_atom_config_for_plugin_mode(config: Any = None): """ logger.info("Generate atom config for plugin mode from passed config") - atom_config = None from atom.plugin import is_vllm, is_sglang from atom.config import set_current_atom_config diff --git a/atom/plugin/register.py b/atom/plugin/register.py index af2427fbf..8ca19e7b1 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -3,6 +3,7 @@ from atom.models.qwen3 import Qwen3ForCausalLM from atom.models.qwen3_moe import Qwen3MoeForCausalLM from atom.models.glm4_moe import Glm4MoeForCausalLM +from atom.models.deepseek_v2 import DeepseekV3ForCausalLM from atom.config import Config from atom.plugin.prepare import is_vllm, is_sglang @@ -12,11 +13,16 @@ "Qwen3ForCausalLM": Qwen3ForCausalLM, "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, "Glm4MoeForCausalLM": Glm4MoeForCausalLM, + "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, } def _register_custom_attention_to_sglang() -> None: + """Override sglang's built-in "aiter" attention backend with ATOM's implementation. + sglang only accepts pre-registered backend names, so we reuse the "aiter" + name to inject ATOMAttnBackendForSgl without modifying sglang source. + """ from sglang.srt.layers.attention.attention_registry import ( register_attention_backend, ) @@ -28,9 +34,9 @@ def _register_custom_attention_to_sglang() -> None: @register_attention_backend("aiter") def create_atom_backend(runner): - from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + from atom.plugin.sglang.sgl_attn_backend import ATOMAttnBackendForSgl - return AiterAttnBackend(runner) + return ATOMAttnBackendForSgl(runner) def register_ops_to_sglang(atom_config: Config) -> None: @@ -41,8 +47,10 @@ def register_ops_to_sglang(atom_config: Config) -> None: def set_attn_cls() -> None: - """ - Set the attention class for constructing the model based on the framework + """Swap ``atom.model_ops.Attention`` to the framework-appropriate class. + + ATOM models reference ``ops.Attention`` generically; this function binds + it to PagedAttention (vLLM) or RadixAttention (sglang) at plugin init time. """ import atom.model_ops as ops diff --git a/atom/plugin/sglang/__init__.py b/atom/plugin/sglang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/sglang/oot/base_model_wrapper.py b/atom/plugin/sglang/oot/base_model_wrapper.py new file mode 100644 index 000000000..97eff70d4 --- /dev/null +++ b/atom/plugin/sglang/oot/base_model_wrapper.py @@ -0,0 +1,105 @@ +"""ATOM model wrappers for SGLang external model loading (OOT). + +Registers model architecture classes via SGLANG_EXTERNAL_MODEL_PACKAGE, +replacing sglang's built-in implementations with ATOM-optimized versions. + +To add a new model, append its architecture class name to _MODEL_NAMES. 
+""" + +import logging +from typing import Iterable, Optional, Tuple, Union + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +logger = logging.getLogger("atom.plugin.sglang.oot") + +_MODEL_NAMES = [ + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3ForCausalLM", +] + + +class _AtomCausalLMBaseForSglangOOT(nn.Module): + """Base ATOM model wrapper conforming to sglang's model interface. + + Delegates model creation and weight loading to ATOM's plugin system, + while providing the forward signature and LogitsProcessorOutput return + type that sglang expects. + """ + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + import atom + + self.model = atom.prepare_model(config=config, engine="sglang") + if self.model is None: + model_arch = getattr(config, "architectures", ["unknown"])[0] + raise ValueError( + f"ATOM failed to create model for architecture {model_arch}" + ) + + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=input_embeds, + forward_batch=forward_batch, + get_embedding=get_embedding, + pp_proxy_tensors=pp_proxy_tensors, + ) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + from atom.model_loader.loader import load_model_in_plugin_mode + + return load_model_in_plugin_mode( + model=self.model, config=self.model.atom_config, prefix="model." + ) + + +EntryClass = [] +for _name in _MODEL_NAMES: + _cls = type(_name, (_AtomCausalLMBaseForSglangOOT,), {}) + globals()[_name] = _cls + EntryClass.append(_cls) diff --git a/atom/plugin/sglang/sgl_attention_mla.py b/atom/plugin/sglang/sgl_attention_mla.py new file mode 100644 index 000000000..c54bce421 --- /dev/null +++ b/atom/plugin/sglang/sgl_attention_mla.py @@ -0,0 +1,645 @@ +"""Sglang-specific MLA forward and weight processing for DeepseekV2/V3. + +This module is lazily imported from deepseek_v2.py only when running in sglang +plugin mode (``is_sglang() == True``). Keeping all sglang-dependent imports +here avoids crashing when sglang is not installed. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple, Optional + +import torch +from aiter.dist.parallel_state import get_tensor_model_parallel_world_size, get_tp_group +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla + +# sglang imports +from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context +from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.models.deepseek_common.utils import ( + _use_aiter_gfx95, + _is_hip, + _is_cpu, + _is_cpu_amx_available, + _is_cuda, + _is_fp8_fnuz, + _is_npu, + awq_dequantize_func, +) +from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( + batched_gemm_afp4wfp4_pre_quant, +) +from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + fp8_dtype, + per_tensor_quant_mla_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, +) +from sglang.srt.utils import bind_or_assign, get_bool_env_var + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + +# bmm_fp8 custom-op wrapper (adapted from sglang forward_mla.py) +if _is_cuda: + from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 + from sglang.srt.utils.custom_op import register_custom_op + + @register_custom_op(mutates_args=["out"]) + def _bmm_fp8_op( + A: torch.Tensor, + B: torch.Tensor, + out: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + ) -> None: + _raw_bmm_fp8(A, B, A_scale, B_scale, out.dtype, out) + + def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + _bmm_fp8_op(A, B, out, A_scale, B_scale) + return out + +else: + + def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") + + +# NamedTuple for prepare → core data flow +class SglPrepareResult(NamedTuple): + q_pe: torch.Tensor + k_pe: torch.Tensor + q_nope_out: torch.Tensor + k_nope: torch.Tensor + forward_batch: Any + zero_allocator: Any + positions: torch.Tensor + topk_indices: Optional[torch.Tensor] + llama_4_scaling: Optional[Any] + + +# Init helpers +def init_sgl_attrs(attn: DeepseekV2MLAAttention, config) -> None: + """Initialise sglang-only attributes on DeepseekV2MLAAttention.""" + from sglang.srt.configs.model_config import is_deepseek_nsa + + attn.use_nsa = is_deepseek_nsa(config) + attn.use_deep_gemm_bmm = False + attn.alt_stream = None + attn.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 + attn.w_kc, attn.w_vc = None, None + attn.w_scale = None + attn.w_scale_k = None + attn.w_scale_v = None + + +# Absorbed batched-matmul (shared by prepare and core) +def mla_absorbed_bmm( + attn: DeepseekV2MLAAttention, + inp: torch.Tensor, + weight: torch.Tensor, + weight_scale: Optional[torch.Tensor], + weight_scale_k: Optional[torch.Tensor], + out_dim: int, +) -> torch.Tensor: + """Batched matmul for MLA absorbed weights (w_kc / w_vc). + + Handles deep_gemm, mxfp4, fp8-triton, fp8-cublas, and bf16 fallback paths. 
+ inp: (num_tokens, num_heads, in_dim) — token-major + Returns: (num_tokens, num_heads, out_dim) — token-major + """ + if attn.use_deep_gemm_bmm: + from sglang.srt.layers import deep_gemm_wrapper + + val, scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(inp.transpose(0, 1)) + ) + out = inp.new_empty((attn.num_local_heads, aligned_m, out_dim)) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (val, scale), + (weight, weight_scale_k), + out, + masked_m, + expected_m, + ) + return out[:, :expected_m, :].transpose(0, 1) + + if _is_hip: + if _use_aiter_gfx95 and weight.dtype == torch.uint8: + x = inp.transpose(0, 1) + out = torch.empty( + x.shape[0], + x.shape[1], + weight.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + weight.transpose(-2, -1), + weight_scale_k.transpose(-2, -1), + torch.bfloat16, + out, + ) + return out.transpose(0, 1) + + if (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) or ( + get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz + ): + out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=inp, + WQ=weight.transpose(-1, -2), + w_scale=weight_scale, + group_size=128, + YQ=None, + transpose_bm=False, + transpose_bm_in=True, + dtype=torch.bfloat16, + ) + return out.transpose(0, 1) + + out = torch.bmm( + inp.to(torch.bfloat16).transpose(0, 1), + weight.to(torch.bfloat16) * weight_scale, + ) + return out.transpose(0, 1) + + # CUDA fp8 path + if weight.dtype == torch.float8_e4m3fn: + val, scale = per_tensor_quant_mla_fp8( + inp.transpose(0, 1), + torch.zeros((1,), dtype=torch.float32, device=inp.device), + ) + out = bmm_fp8(val, weight, scale, weight_scale, torch.bfloat16) + return out.transpose(0, 1) + + # bf16 fallback + return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) + + +# Forward: prepare → core +def forward_sgl_prepare( + attn: DeepseekV2MLAAttention, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs, +) -> SglPrepareResult: + """Prepare QKV for sglang MLA attention (adapted from sglang forward_absorb_prepare).""" + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + + forward_batch = model_kwargs.get("forward_batch", None) + zero_allocator = model_kwargs.get("zero_allocator", None) + llama_4_scaling = model_kwargs.get("llama_4_scaling", None) + q_lora = None + topk_indices = None + + if attn.q_lora_rank is not None: + q, latent_cache = ( + get_attn_tp_context() + .fetch_qkv_latent() + .split( + [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], + dim=-1, + ) + ) + + if q.shape[0] != positions.shape[0] and get_tensor_model_parallel_world_size() > 1: + qkv_lora = torch.cat([q, latent_cache], dim=-1) + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] < positions.shape[0]: + raise RuntimeError( + f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, " + f"expected {positions.shape[0]}" + ) + qkv_lora = qkv_lora[: positions.shape[0]] + q, latent_cache = torch.split( + qkv_lora, + [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], + dim=-1, + ) + + k_nope = latent_cache[..., : attn.kv_lora_rank] + + # overlap qk norm + if attn.alt_stream is not None and get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + attn.alt_stream.wait_stream(current_stream) + q = attn.q_a_layernorm(q) + with torch.cuda.stream(attn.alt_stream): + k_nope = attn.kv_a_layernorm(k_nope) 
+ current_stream.wait_stream(attn.alt_stream) + else: + q = attn.q_a_layernorm(q) + k_nope = attn.kv_a_layernorm(k_nope) + + if attn.use_nsa: + if q_lora is None: + q_lora = q + + # overlap q_b_proj and indexer during decode + if ( + attn.alt_stream is not None + and get_is_capture_mode() + and forward_batch.forward_mode.is_decode_or_idle() + and q_lora is not None + ): + current_stream = torch.cuda.current_stream() + attn.alt_stream.wait_stream(current_stream) + with torch.cuda.stream(attn.alt_stream): + k_nope = k_nope.unsqueeze(1) + q = attn.q_b_proj(q).view(-1, attn.num_local_heads, attn.qk_head_dim) + topk_indices = attn.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=attn.layer_num, + ) + current_stream.wait_stream(attn.alt_stream) + else: + k_nope = k_nope.unsqueeze(1) + q = attn.q_b_proj(q).view(-1, attn.num_local_heads, attn.qk_head_dim) + if q_lora is not None: + topk_indices = attn.indexer( + x=hidden_states, + q_lora=q_lora, + positions=positions, + forward_batch=forward_batch, + layer_id=attn.layer_num, + ) + else: + q = attn.q_proj(hidden_states).view(-1, attn.num_local_heads, attn.qk_head_dim) + latent_cache = attn.kv_a_proj_with_mqa(hidden_states)[0] + k_nope = latent_cache[..., : attn.kv_lora_rank] + k_nope = attn.kv_a_layernorm(k_nope).unsqueeze(1) + + q_nope, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) + k_pe = latent_cache[..., attn.kv_lora_rank :].unsqueeze(1) + + q_nope_out = mla_absorbed_bmm( + attn, q_nope, attn.w_kc, attn.w_scale, attn.w_scale_k, attn.kv_lora_rank + ) + + if attn.rotary_emb is not None and not attn.use_fused_qk_rope_concat_and_cache_mla: + q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) + + if nsa_use_prefill_cp(forward_batch): + k_nope, k_pe = attn.rebuild_cp_kv_cache( + latent_cache, forward_batch, k_nope, k_pe + ) + + return SglPrepareResult( + q_pe=q_pe, + k_pe=k_pe, + q_nope_out=q_nope_out, + k_nope=k_nope, + forward_batch=forward_batch, + zero_allocator=zero_allocator, + positions=positions, + topk_indices=topk_indices, + llama_4_scaling=llama_4_scaling, + ) + + +def forward_sgl_core( + attn: DeepseekV2MLAAttention, + prepared: SglPrepareResult, +) -> torch.Tensor: + """Core MLA attention computation for sglang (adapted from sglang forward_absorb_core).""" + save_kv_cache = True + + if attn.use_fused_qk_rope_concat_and_cache_mla: + cos = attn.rotary_emb.cos_cache + sin = attn.rotary_emb.sin_cache + kv_cache = prepared.forward_batch.token_to_kv_pool.get_key_buffer(attn.layer_num) + k_scale = attn.mla_attn.attn.k_scale + + q, _, k_pe_roped, _ = fused_qk_rope_cat_and_cache_mla( + prepared.q_nope_out, + prepared.q_pe, + prepared.k_nope, + prepared.k_pe, + kv_cache, + prepared.forward_batch.out_cache_loc, + prepared.positions, + cos, + sin, + k_scale, + attn.rotary_emb.is_neox_style, + q_out_dtype=prepared.q_nope_out.dtype, + ) + k = torch.cat([prepared.k_nope, k_pe_roped], dim=-1) + save_kv_cache = False + else: + q = torch.cat([prepared.q_nope_out, prepared.q_pe], dim=-1) + k = torch.cat([prepared.k_nope, prepared.k_pe], dim=-1) + + if prepared.llama_4_scaling is not None: + q = q * prepared.llama_4_scaling + + extra_kwargs = {} + if prepared.topk_indices is not None: + extra_kwargs["topk_indices"] = prepared.topk_indices + + attn_output = attn.mla_attn( + q, + k, + prepared.k_nope, + forward_batch=prepared.forward_batch, + save_kv_cache=save_kv_cache, + **extra_kwargs, + ) + attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) + + 
# up-proj by w_vc + attn_bmm_output = mla_absorbed_bmm( + attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim + ).flatten(1, 2) + + return attn.o_proj(attn_bmm_output) + + +def prepare_qkv_latent( + attn: DeepseekV2MLAAttention, + hidden_states: torch.Tensor, + forward_batch, +) -> torch.Tensor: + """Prepare QKV latent tensor for the sglang communicator.""" + assert attn.q_lora_rank is not None + hidden_states_scale = None + if isinstance(hidden_states, tuple): + hidden_states, hidden_states_scale = hidden_states + qkv_lora = attn.fused_qkv_a_proj(hidden_states, hidden_states_scale) + + # Fallback: when communicator does not enable input_scattered gather, + # force qkv latent token dimension to align with positions. + expected_tokens = 0 + if hasattr(forward_batch, "positions") and forward_batch.positions is not None: + expected_tokens = int(forward_batch.positions.shape[0]) + if expected_tokens <= 0: + expected_tokens = int(getattr(forward_batch, "seq_lens_sum", 0) or 0) + + if ( + expected_tokens > 0 + and qkv_lora.shape[0] != expected_tokens + and get_tensor_model_parallel_world_size() > 1 + ): + qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) + if qkv_lora.shape[0] > expected_tokens: + qkv_lora = qkv_lora[:expected_tokens] + elif qkv_lora.shape[0] < expected_tokens: + raise RuntimeError( + f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " + f"expected {expected_tokens}" + ) + return qkv_lora + + +# Top-level forward entry point +def forward_sgl_plugin_mode( + attn: DeepseekV2MLAAttention, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **model_kwargs, +) -> torch.Tensor: + """Full MLA forward in sglang plugin mode.""" + forward_batch = model_kwargs.get("forward_batch", None) + if forward_batch is None: + raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") + + attn_tp_context = get_attn_tp_context() + with attn_tp_context.maybe_input_scattered(forward_batch): + if attn.q_lora_rank is not None: + attn_tp_context.set_attn_inputs( + AttentionInputs( + hidden_states, + forward_batch, + lambda hs, fb: prepare_qkv_latent(attn, hs, fb), + ) + ) + prepared = forward_sgl_prepare(attn, positions, hidden_states, **model_kwargs) + return forward_sgl_core(attn, prepared) + + +# Weight post-processing: decomposed into sub-functions +def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: + """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" + if hasattr(attn.kv_b_proj, "qweight"): + awq_dequant = awq_dequantize_func() + if awq_dequant is None: + raise ValueError("AWQ dequantize function is not supported for current device") + w = awq_dequant( + attn.kv_b_proj.qweight, + attn.kv_b_proj.scales, + attn.kv_b_proj.qzeros, + ).T + else: + w = attn.kv_b_proj.weight + + # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes. + # View-cast back to fn so the normalize path works correctly. 
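+    # (Both formats are 8-bit, so .view() is a pure bit reinterpretation.)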
+ if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: + w = w.view(torch.float8_e4m3fn) + + return w + + +def _get_weight_block_size(attn: DeepseekV2MLAAttention) -> Optional[list[int]]: + """Derive weight_block_size from ATOM's quant_type system.""" + from aiter import QuantType as _AiterQuantType + + qt = getattr(attn.kv_b_proj, "quant_type", None) + if qt == _AiterQuantType.per_1x128: + return [128, 128] + elif qt == _AiterQuantType.per_1x32: + return [1, 32] + return None + + +def _process_fp8_weight( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + weight_block_size: Optional[list[int]], +) -> tuple[torch.Tensor, bool, Optional[torch.Tensor]]: + """Process FP8 weights for kv_b_proj. + + Returns (w, use_deep_gemm_bmm, block_scale). + """ + from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz + from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + inverse_transform_scale_ue8m0, + ) + from sglang.srt.layers.deep_gemm_wrapper import ENABLE_JIT_DEEPGEMM, DEEPGEMM_BLACKWELL + from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 + + use_deep_gemm_bmm = False + block_scale = None + + if weight_block_size is not None: + assert hasattr(attn.kv_b_proj, "weight_scale_inv") or hasattr( + attn.kv_b_proj, "weight_scale" + ) + weight_scale = ( + attn.kv_b_proj.weight_scale + if hasattr(attn.kv_b_proj, "weight_scale") + else attn.kv_b_proj.weight_scale_inv + ) + + if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fn: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, weight_scale=weight_scale, input_scale=None + ) + else: + weight = w + + if ( + should_deepgemm_weight_requant_ue8m0(weight_block_size=weight_block_size) + and getattr(weight_scale, "format_ue8m0", False) + ): + weight_scale = inverse_transform_scale_ue8m0( + weight_scale, mn=weight.shape[-2] + ) + + if _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128: + if ( + ENABLE_JIT_DEEPGEMM + and not DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, weight_scale, weight_block_size, torch.bfloat16 + ) + else: + w, scale = block_quant_to_tensor_quant(weight, weight_scale, weight_block_size) + attn.w_scale = scale + else: + if w.dtype == torch.float8_e4m3fn and _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, weight_scale=attn.kv_b_proj.weight_scale, input_scale=None + ) + else: + weight = w + weight_scale = attn.kv_b_proj.weight_scale + + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + attn.w_scale = scale + + return w, use_deep_gemm_bmm, block_scale + + +def _process_int8_weight( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + weight_block_size: Optional[list[int]], +) -> torch.Tensor: + """Process INT8 weights for kv_b_proj.""" + from sglang.srt.layers.quantization.int8_utils import block_dequant as int8_block_dequant + + if weight_block_size is not None: + assert hasattr(attn.kv_b_proj, "weight_scale_inv") + return int8_block_dequant( + w, attn.kv_b_proj.weight_scale_inv, weight_block_size + ).to(torch.bfloat16) + else: + return w.to(torch.bfloat16) * attn.kv_b_proj.weight_scale.to(torch.bfloat16) + + +def _split_and_assign_kc_vc( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + use_deep_gemm_bmm: bool, + block_scale: Optional[torch.Tensor], + weight_block_size: Optional[list[int]], +) -> 
None: + """Split weight into kc/vc and assign to attn.""" + from atom.model_ops.utils import quark_post_load_weights + + w_kc, w_vc = w.unflatten( + 0, (-1, attn.qk_nope_head_dim + attn.v_head_dim) + ).split([attn.qk_nope_head_dim, attn.v_head_dim], dim=1) + + # quark fp4 special path + quant_method = getattr(attn.kv_b_proj, "quant_method", None) + quant_config = getattr(quant_method, "quant_config", None) + if ( + _use_aiter_gfx95 + and quant_config is not None + and quant_config.get_name() == "quark" + ): + w_kc, attn.w_scale_k, w_vc, attn.w_scale_v = quark_post_load_weights( + attn, w, "mxfp4" + ) + + if not use_deep_gemm_bmm: + attn.w_kc = bind_or_assign( + attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) + ) + w_vc = w_vc.contiguous().transpose(1, 2) + if _is_npu: + w_vc = w_vc.contiguous() + attn.w_vc = bind_or_assign(attn.w_vc, w_vc) + + if hasattr(attn.kv_b_proj, "weight_scale") and attn.w_scale is None: + attn.w_scale = bind_or_assign(attn.w_scale, attn.kv_b_proj.weight_scale) + if _is_hip: + attn.w_scale *= 2.0 + + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + attn.w_kc = attn.w_kc.to(torch.bfloat16) * attn.w_scale + attn.w_vc = attn.w_vc.to(torch.bfloat16) * attn.w_scale + else: + num_tiles_k = attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + + attn.w_scale_k = bind_or_assign(attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()) + attn.w_scale_v = bind_or_assign(attn.w_scale_v, ws_vc.contiguous()) + attn.w_kc = bind_or_assign(attn.w_kc, w_kc.transpose(1, 2).contiguous()) + attn.w_vc = bind_or_assign(attn.w_vc, w_vc.contiguous()) + attn.use_deep_gemm_bmm = True + + +def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: + """Process kv_b_proj weights after loading for sglang MLA mode. + + Orchestrates reading, quantization handling, and splitting of + kv_b_proj into absorbed w_kc / w_vc weights. 
+ """ + w = _read_kv_b_proj_weight(attn) + weight_block_size = _get_weight_block_size(attn) + + use_deep_gemm_bmm = False + block_scale = None + + # fp8 path + if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + w, use_deep_gemm_bmm, block_scale = _process_fp8_weight( + attn, w, weight_block_size + ) + + # int8 path + if w.dtype == torch.int8: + w = _process_int8_weight(attn, w, weight_block_size) + + # split and assign kc/vc + _split_and_assign_kc_vc(attn, w, use_deep_gemm_bmm, block_scale, weight_block_size) diff --git a/atom/plugin/sglang/sgl_attn_backend.py b/atom/plugin/sglang/sgl_attn_backend.py new file mode 100644 index 000000000..8bd49661c --- /dev/null +++ b/atom/plugin/sglang/sgl_attn_backend.py @@ -0,0 +1,1240 @@ +from __future__ import annotations + +""" +end to end attention solution with aiter kernels +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +import sglang.srt.layers.attention.aiter_backend as _sglang_aiter +from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import get_bool_env_var + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInput + +try: + from aiter import ( + flash_attn_varlen_func, + dtypes, + get_pa_metadata_info_v1, + get_pa_metadata_v1, + pa_fwd_asm, + pa_persistent_fwd, + mla_decode_fwd, + ) +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." 
+ ) + +# MLA prefill kernels - imported separately to avoid breaking the main aiter imports +mla_prefill_ps_asm_fwd = None +mla_reduce_v1 = None +mla_prefill_fwd = None +try: + from aiter import mla_prefill_ps_asm_fwd +except ImportError: + pass +try: + from aiter import mla_reduce_v1 +except ImportError: + pass +try: + from aiter.mla import mla_prefill_fwd + from aiter.mla import mla_decode_fwd +except ImportError: + pass + +import triton +import triton.language as tl + + +@triton.jit +def reshape_and_cache_shuffle_kernel( + key_ptr, # [num_tokens, num_kv_heads, head_size] + value_ptr, # [num_tokens, num_kv_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] + value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] + slot_mapping_ptr, # [num_tokens] + k_scale_ptr, + v_scale_ptr, + x, + k_stride0, + v_stride0, + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE: tl.constexpr, + QUANT: tl.constexpr, +): + tid = tl.program_id(0) + head_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + src_offset_k = tid * k_stride0 + head_id * head_size + src_offset_v = tid * v_stride0 + head_id * head_size + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + dst_offset = ( + block_id * num_kv_heads * head_size * block_size + + head_id * head_size * block_size + ) + dst_k_shuffle_offset = ( + dst_offset + offset // x * block_size * x + block_offset * x + offset % x + ) + dst_v_shuffle_offset = ( + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x + ) + k_val = tl.load(key_ptr + src_offset_k + offset) + v_val = tl.load(value_ptr + src_offset_v + offset) + if QUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = key_cache_ptr.type.element_ty + v_dtype = value_cache_ptr.type.element_ty + k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) + v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) + tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) + tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + + +def reshape_and_cache_shuffle_triton( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scales: torch.Tensor, + v_scales: torch.Tensor, +): + num_tokens = slot_mapping.shape[0] + _, num_kv_heads, head_size = key.shape + num_blocks, block_size, _, _ = key_cache.shape + x = 16 // key_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=key_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=value_cache.dtype, + device="meta", + ) + new_key_cache = key_cache.view_as(k_cache_template) + new_value_cache = value_cache.view_as(v_cache_template) + QUANT = False + if kv_cache_dtype.startswith("fp8"): + QUANT = True + grid = ( + num_tokens, + num_kv_heads, + ) + reshape_and_cache_shuffle_kernel[grid]( + key, + value, + new_key_cache, + new_value_cache, + slot_mapping, + k_scales, + v_scales, + x, + key.stride(0), + value.stride(0), + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE=head_size, + QUANT=QUANT, + ) + + +@dataclass +class ForwardMetadata: + """Per-batch metadata consumed by ATOM's attention kernels (pa_fwd_asm, mla_decode_fwd, etc.).""" + # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA 
mode + kv_indptr: Optional[torch.Tensor] + kv_indices: Optional[torch.Tensor] + qo_indptr: Optional[torch.Tensor] + kv_last_page_len: Optional[torch.Tensor] + max_q_len: Optional[int] + max_kv_len: Optional[int] + page_table: Optional[torch.Tensor] + kv_lens: Optional[torch.Tensor] + # mla + work_metadata: Optional[torch.Tensor] = None + work_info_set: Optional[torch.Tensor] = None + work_indptr: Optional[torch.Tensor] = None + reduce_indptr: Optional[torch.Tensor] = None + reduce_final_map: Optional[torch.Tensor] = None + reduce_partial_map: Optional[torch.Tensor] = None + fp8_prefill_kv_indices: Optional[torch.Tensor] = None + num_kv_splits: Optional[int] = None + # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) + pa_metadata_qo_indptr: Optional[torch.Tensor] = None + pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None + pa_metadata_kv_indices: Optional[torch.Tensor] = None + pa_metadata_context_lens: Optional[torch.Tensor] = None + pa_metadata_max_qlen: Optional[int] = None + pa_metadata_tp_q_head_num: Optional[int] = None + + + +class ATOMAttnBackendForSgl(AiterAttnBackend): + """ATOM's custom attention backend for sglang plugin mode. + + Extends sglang's AiterAttnBackend with ATOM-specific optimisations: + page-table management, pa_persistent_fwd decode path, and MLA + prefill kernels (fp8, decompress, absorbed). Registered to sglang + via atom.plugin.register._register_custom_attention_to_sglang(). + """ + + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + super().__init__(model_runner, skip_prefill, kv_indptr_buf) + mapping = getattr( + model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None + ) + + if isinstance(mapping, dict) and mapping: + first_full_attn_id = next(iter(mapping.keys())) + else: + first_full_attn_id = 0 + + self.q_dtype = model_runner.dtype # Save q dtype for pa_metadata building + + # assert not self.use_mla, "MLA mode is not implemented yet in ATOMAttnBackendForSgl." 
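+
+        # The per-batch buffers below are allocated once at their maximum sizes
+        # (max batch size, max_context_len) and only sliced per forward pass, so
+        # the same storage serves CUDA-graph capture and eager execution.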
+ + # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs] + # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size] + max_bs = model_runner.req_to_token_pool.size + self.pa_decode_qo_indptr = torch.arange( + 0, max_bs + 1, dtype=torch.int32, device=model_runner.device + ) + self.seq_lens = torch.zeros( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=model_runner.device, + ) + # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=model_runner.device + ) + + if not self.use_mla: + # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes) + max_num_blocks_per_seq = ( + self.max_context_len + self.page_size - 1 + ) // self.page_size + max_total_blocks = max_bs * max_num_blocks_per_seq + self.pa_kv_indices = torch.zeros( + max_total_blocks, dtype=torch.int32, device=self.device + ) + # Pre-allocate pa_kv_indptr buffer (similar to self.kv_indptr, but dedicated for pa_persistent_fwd) + self.pa_kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=self.device + ) + # Pre-initialized batch indices [0, 1, 2, ..., max_bs-1] for Triton kernel + self.pa_batch_indices = torch.arange( + 0, max_bs, dtype=torch.int32, device=self.device + ) + + # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0) + + self.logits_soft_cap = 0.0 + + self.forward_metadata: ForwardMetadata = None + + self.pa_metadata_buffers = None + + k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id) + num_slots, num_kv_heads, _ = k_buffer.shape + block_size = self.page_size + num_blocks = num_slots // block_size + max_total_tokens = num_blocks * block_size + self.k_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.v_qscale = torch.ones( + num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device + ) + self.decode_using_pa_ps = self.page_size == 1024 + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables for triton attention backend.""" + if forward_batch.forward_mode.is_decode_or_idle(): + self._init_forward_metadata_decode(forward_batch) + else: + self._init_forward_metadata_extend(forward_batch) + self._fixup_page_table(forward_batch) + + def _init_forward_metadata_decode(self, forward_batch: ForwardBatch): + bs = forward_batch.batch_size + spec_info = forward_batch.spec_info + + if spec_info is None: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 + + if self.use_mla: + self._init_decode_mla(bs, kv_indptr, kv_indices) + else: + self._init_decode_mha(bs, kv_indptr, kv_indices, forward_batch) + + def _init_decode_mla(self, bs, kv_indptr, kv_indices): + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = 
torch.cumsum(self.kv_last_page_len[:bs], dim=0) + kv_last_page_len = self.kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + ( + work_metadata, work_indptr, work_info_set, + reduce_indptr, reduce_final_map, reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_q_len, bs) + num_kv_splits = self.max_split_per_batch + self.make_mla_meta_data( + qo_indptr, kv_indptr, kv_last_page_len, + work_metadata, work_info_set, work_indptr, + reduce_indptr, reduce_final_map, reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + self.forward_metadata = ForwardMetadata( + kv_indptr, kv_indices, qo_indptr, kv_last_page_len, + max_q_len, None, None, None, + work_metadata=work_metadata, work_info_set=work_info_set, + work_indptr=work_indptr, reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + + def _init_decode_mha(self, bs, kv_indptr, kv_indices, forward_batch): + if self.decode_using_pa_ps: + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(forward_batch.seq_lens, non_blocking=True) + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + page_table = self.req_to_token[ + forward_batch.req_pool_indices[:, None], + self.strided_indices[:max_seq_pages][None, :], + ] + page_table_persistent[:bs, :max_seq_pages].copy_( + page_table // self.page_size, non_blocking=True + ) + self.forward_metadata = ForwardMetadata( + kv_indptr, kv_indices, None, None, 1, None, + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + else: + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( + kv_indptr, kv_indices, None, None, 1, None, + page_table, forward_batch.seq_lens, + ) + + def _init_forward_metadata_extend(self, forward_batch: ForwardBatch): + bs = forward_batch.batch_size + + if self.use_mla: + self._init_extend_mla(bs, forward_batch) + else: + self._init_extend_mha(bs, forward_batch) + + def _init_extend_mla(self, bs, forward_batch): + self.mla_indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + forward_batch.extend_seq_lens, + forward_batch.extend_seq_lens.max().item(), + forward_batch.seq_lens.max().item(), + spec_info=None, + ) + + max_q_len = self.mla_indices_updater_prefill.max_q_len + qo_indptr = self.mla_indices_updater_prefill.qo_indptr + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + fp8_prefill_kv_indices = None + + from sglang.srt.utils import is_gfx95_supported + _use_fp8_prefill_attn = ( + get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") + and is_gfx95_supported() + ) + if _use_fp8_prefill_attn: + tile_q = 256 + qlen_granularity = tile_q // (self.num_head // 
self.num_kv_head) + ( + work_metadata, work_indptr, work_info_set, + reduce_indptr, reduce_final_map, reduce_partial_map, + ) = self.make_mla_prefill_ps_meta_data_buffer( + bs, max_q_len, qlen_granularity + ) + self.make_mla_prefill_ps_meta_data( + qo_indptr, qo_indptr, forward_batch.seq_lens, + work_metadata, work_indptr, work_info_set, + reduce_indptr, reduce_final_map, reduce_partial_map, + is_causal=True, + ) + total_s = int(forward_batch.extend_seq_lens.sum()) + fp8_prefill_kv_indices = torch.arange( + total_s, device=self.device, dtype=torch.int32 + ) + + self.forward_metadata = ForwardMetadata( + self.mla_indices_updater_prefill.kv_indptr, + self.mla_indices_updater_prefill.kv_indices, + qo_indptr, + self.kv_last_page_len[:bs], + max_q_len, + self.mla_indices_updater_prefill.max_kv_len, + None, None, + work_metadata=work_metadata, work_info_set=work_info_set, + work_indptr=work_indptr, reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, reduce_partial_map=reduce_partial_map, + fp8_prefill_kv_indices=fp8_prefill_kv_indices, + ) + + def _init_extend_mha(self, bs, forward_batch): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + forward_batch.extend_prefix_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + page_table = forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, : + ] + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.qo_indptr[: bs + 1], + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + None, + forward_batch.seq_lens, + ) + + def _fixup_page_table(self, forward_batch: ForwardBatch): + """Post-process page_table for non-MLA extend mode.""" + if ( + forward_batch.forward_mode.is_extend() + and not self.use_mla + and self.forward_metadata.page_table is not None + ): + if self.page_size > 1: + seq_lens_cpu = forward_batch.seq_lens_cpu + if seq_lens_cpu is None: + seq_lens_cpu = forward_batch.seq_lens.cpu() + max_seq_pages = ( + seq_lens_cpu.max().item() + self.page_size - 1 + ) // self.page_size + 1 + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[ + :, self.strided_indices[:max_seq_pages] + ] + // self.page_size + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_prefill(forward_batch.batch_size) + if ( + not self.decode_using_pa_ps + and self.page_size > 1 + and self.forward_metadata.page_table is not None + ): + self.forward_metadata.page_table = ( + self.forward_metadata.page_table[:, self.strided_indices] + // self.page_size + ) + + def _ensure_buffer(self, name, size, dtype, zero=True): + """Allocate or reuse a pa_metadata buffer, growing if needed.""" + if self.pa_metadata_buffers is None: + self.pa_metadata_buffers = {} + size_val = size[0] if isinstance(size, (tuple, list)) else size + buf = self.pa_metadata_buffers.get(name) + needs_alloc = ( + buf is None + or buf.shape[0] < size_val + or (isinstance(size, (tuple, list)) and len(buf.shape) < len(size)) + ) + if needs_alloc: + factory = torch.zeros if zero else torch.empty + self.pa_metadata_buffers[name] = factory(size, dtype=dtype, device=self.device) + elif zero: + self.pa_metadata_buffers[name].zero_() + + def _allocate_pa_metadata_buffers(self, buffer_specs): + """Allocate or reuse pa_metadata buffers. + + Args: + buffer_specs: sequence of ((size, dtype), ...) 
tuples from get_pa_metadata_info_v1, + in order: work_metadata_ptrs, work_indptr, work_info, + reduce_indptr, reduce_final_map, reduce_partial_map. + """ + names = [ + "work_metadata_ptrs", "work_indptr", "work_info", + "reduce_indptr", "reduce_final_map", "reduce_partial_map", + ] + zero_flags = [False, True, True, True, True, True] + for name, (size, dtype), zero in zip(names, buffer_specs, zero_flags): + self._ensure_buffer(name, size, dtype, zero=zero) + + def _build_pa_metadata_for_decode( + self, + batch_size: int, + tp_q_head_num: Optional[int] = None, + ): + """Build pa_metadata buffers for pa_persistent_fwd in decode mode. + + This method prepares all metadata buffers needed for pa_persistent_fwd kernel. + The metadata can be reused across multiple layers in the same forward pass. + + Args: + batch_size: Batch size for the current forward pass + tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. + """ + max_qlen = 1 + + # Use provided tp_q_head_num or default to self.num_head + if tp_q_head_num is None: + tp_q_head_num = self.num_head + + buffer_specs = get_pa_metadata_info_v1(batch_size, self.num_kv_head) + self._allocate_pa_metadata_buffers(buffer_specs) + qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] + + # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) + # Note: kv_lens comes from self.seq_lens which is already int32 + context_lens = self.forward_metadata.kv_lens + + kernel_block_size = self.page_size + num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size + # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync + # page_table shape: [batch_size, max_num_blocks_per_seq] + # Note: page_table comes from self.page_table which is already int32 and always set before this call + page_table = self.forward_metadata.page_table + + # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], # [0, 1, 2, ..., batch_size-1] + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr + kv_indices = self.pa_kv_indices + + get_pa_metadata_v1( + seqlens_qo_indptr=qo_indptr, + pages_kv_indptr=pages_kv_indptr, + context_lens=context_lens.int(), + num_heads_per_head_k=tp_q_head_num // self.num_kv_head, + num_heads_k=self.num_kv_head, + is_causal=True, + work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"], + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + kv_granularity=max(kernel_block_size, 16), + block_size=kernel_block_size, + max_seqlen_qo=max_qlen, + uni_seqlen_qo=max_qlen, + fast_mode=True, + topk=-1, + max_split_per_batch=-1, + ) + # Store computed values in ForwardMetadata for reuse in forward_decode + self.forward_metadata.pa_metadata_qo_indptr = qo_indptr + 
self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr + self.forward_metadata.pa_metadata_kv_indices = kv_indices + self.forward_metadata.pa_metadata_context_lens = context_lens + self.forward_metadata.pa_metadata_max_qlen = max_qlen + self.forward_metadata.pa_metadata_tp_q_head_num = tp_q_head_num + + def _build_pa_metadata_for_prefill(self, batch_size: int): + """Build metadata for mha_batch_prefill_func in prefill mode. + + This method prepares page-level metadata needed for mha_batch_prefill_func. + The metadata is computed once per forward pass and reused across all layers. + """ + block_size = self.page_size + context_lens = self.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + # Page-level kv_indptr (reuse pa_kv_indptr buffer) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Build kv_indices from page_table using triton kernel + page_table = self.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int, device=self.device) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + # Always use preshuffle layout for pa_fwd_asm + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=self.device, + ) + self.seq_lens = torch.zeros((max_bs,), dtype=torch.int32, device=self.device) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ) + + if self.use_mla and _sglang_aiter._use_mla_ps_kernel: + max_seqlen_qo = 1 + ( + self.work_metadata, + self.work_indptr, + self.work_info_set, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + ) = self.make_mla_decode_meta_data_buffer(max_seqlen_qo, max_bs) + elif self.use_mla: + self.work_metadata = None + self.work_indptr = None + self.work_info_set = None + self.reduce_indptr = None + self.reduce_final_map = None + self.reduce_partial_map = None + + if self.decode_using_pa_ps and not self.use_mla: + buffer_specs = get_pa_metadata_info_v1(max_bs, self.num_kv_head) + self._allocate_pa_metadata_buffers(buffer_specs) + + def _init_mla_cuda_graph_metadata(self, bs, req_pool_indices, seq_lens): + """Shared MLA decode metadata setup for CUDA graph capture/replay.""" + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = 1 + + work_metadata = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = 
None + num_kv_splits = None + + if _sglang_aiter._use_mla_ps_kernel: + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_q_len, + None, + None, + None, + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + if not forward_mode.is_decode_or_idle(): + raise ValueError(f"Invalid mode: {forward_mode=}") + + if self.use_mla: + self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) + else: + page_table = self.page_table[:bs, :] + self.seq_lens[:bs].copy_(seq_lens, non_blocking=True) + seq_lens_persistent = self.seq_lens[:bs] + self.forward_metadata = ForwardMetadata( + None, None, None, None, 1, None, page_table, seq_lens_persistent, + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + ): + if not forward_mode.is_decode_or_idle(): + raise ValueError("Invalid forward mode") + + if self.use_mla: + self._init_mla_cuda_graph_metadata(bs, req_pool_indices, seq_lens) + else: + page_table_persistent = self.page_table + seq_lens_persistent = self.seq_lens + seq_lens_persistent.fill_(0) + page_table_persistent.fill_(0) + seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True) + max_seq_pages = (seq_lens_cpu.max().item() + self.page_size - 1) // self.page_size + 1 + page_table = self.req_to_token[req_pool_indices[:, None], self.strided_indices[:max_seq_pages][None, :],] + page_table_persistent[:bs, :max_seq_pages].copy_(page_table // self.page_size, non_blocking=True) + + self.forward_metadata = ForwardMetadata( + None, None, None, None, 1, None, + page_table_persistent[:bs, :max_seq_pages], + seq_lens_persistent[:bs], + ) + if self.decode_using_pa_ps: + self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + + def set_kv_buffer_with_layout_shuffle( + self, + cache_loc, + k, + v, + k_buffer, + v_buffer, + k_scale, + v_scale, + block_size, + ): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) 
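+        # Give the value cache the same [num_blocks, block_size, num_kv_heads, head_dim]
+        # paged view as the key cache; reshape_and_cache_shuffle_triton below scatters
+        # the incoming k/v tokens into these paged buffers at cache_loc.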
+ v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + "auto", + k_scale, + v_scale, + ) + + def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + self.logits_soft_cap = layer.logit_cap + + if k is not None: + assert v is not None + if save_kv_cache: + if self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v + ) + else: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + self.set_kv_buffer_with_layout_shuffle( + cache_loc, k, v, k_buffer, v_buffer, + layer.k_scale, layer.v_scale, self.page_size, + ) + + if self.use_mla: + return self._forward_extend_mla(q, k, v, layer, forward_batch) + else: + return self._forward_extend_mha(q, k, v, layer, forward_batch) + + def _forward_extend_mha(self, q, k, v, layer, forward_batch): + """Non-MLA extend path: standard MHA with flash_attn_varlen_func.""" + seqlens_in_batch = forward_batch.seq_lens + cu_seqlens_q = torch.nn.functional.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0) + ) + if q.dtype != k.dtype and k.dtype == dtypes.fp8: + q = q.to(dtypes.fp8) + o = flash_attn_varlen_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=self.forward_metadata.max_q_len, + max_seqlen_k=self.forward_metadata.max_kv_len, + min_seqlen_q=0, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1, 0), + sink_ptr=None, + ) + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def _forward_extend_mla(self, q, k, v, layer, forward_batch): + """MLA extend path: ported from sglang aiter_backend forward_extend MLA logic.""" + max_q_len = self.forward_metadata.max_q_len + max_kv_len = self.forward_metadata.max_kv_len + kv_indptr = self.forward_metadata.kv_indptr + kv_indices = self.forward_metadata.kv_indices + qo_indptr = self.forward_metadata.qo_indptr + + K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + kv_lora_rank = V_Buffer.shape[-1] + qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank + qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim + + assert len(q.shape) == 3 + assert len(k.shape) == 3 + assert len(v.shape) == 3 + + if ( + forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ): + return self._forward_extend_mla_normal( + q, k, v, layer, forward_batch, + K_Buffer, V_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ) + elif ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ): + return self._forward_extend_mla_speculative( + q, layer, K_Buffer, qo_indptr, + ) + else: + raise ValueError( + f"Invalid forward mode for MLA extend: {forward_batch.forward_mode=}" + ) + + def _forward_extend_mla_normal( + self, q, k, v, layer, forward_batch, + K_Buffer, V_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + 
max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ): + """Normal MLA extend (not target_verify, not draft_extend).""" + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + + if kv_indices.shape[0] == 0 or extend_no_prefix: + return self._extend_mla_no_prefix( + q, k, v, layer, kv_lora_rank, qk_rope_head_dim, + max_q_len, qo_indptr, + ) + elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim): + return self._extend_mla_decompress_prefix( + q, layer, forward_batch, K_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ) + else: + return self._extend_mla_absorbed_prefix( + q, layer, K_Buffer, kv_indptr, kv_indices, qo_indptr, + ) + + def _extend_mla_no_prefix( + self, q, k, v, layer, kv_lora_rank, qk_rope_head_dim, + max_q_len, qo_indptr, + ): + """No-prefix prefill: FP8 kernel, mla_prefill_fwd, or flash_attn fallback.""" + if self.forward_metadata.fp8_prefill_kv_indices is not None: + return self._extend_mla_fp8_prefill(q, k, v, layer, max_q_len, qo_indptr) + + if layer.qk_head_dim == (kv_lora_rank + qk_rope_head_dim) and mla_prefill_fwd is not None: + # Absorbed MLA: head_dim (576) exceeds CK limit (256), + # use mla_prefill_fwd which natively supports large MLA head dims. + if layer.qk_head_dim != layer.v_head_dim: + output = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + output = torch.empty_like(q) + total_s = q.shape[0] + temp_kv_indices = torch.arange(total_s, device=q.device, dtype=torch.int32) + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.view(-1, 1, 1, layer.qk_head_dim), + output.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, qo_indptr, temp_kv_indices, + self.forward_metadata.kv_last_page_len, + max_q_len, layer.scaling, layer.logit_cap, + ) + return output + + return flash_attn_varlen_func( + q, k, v, qo_indptr, qo_indptr, max_q_len, max_q_len, + softmax_scale=layer.scaling, causal=True, + ) + + def _extend_mla_fp8_prefill(self, q, k, v, layer, max_q_len, qo_indptr): + """FP8 prefill path using mla_prefill_ps_asm_fwd + mla_reduce_v1.""" + total_s = q.shape[0] + nhead = layer.tp_q_head_num + v_head_dim = layer.v_head_dim + md = self.forward_metadata + + if q.dtype != dtypes.fp8: + q = q.to(dtypes.fp8) + if k.dtype != dtypes.fp8: + k = k.to(dtypes.fp8) + if v.dtype != dtypes.fp8: + v = v.to(dtypes.fp8) + one_scale = torch.ones((), dtype=torch.float32, device=q.device) + + tile_q = 256 + logits = torch.empty( + (md.reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), + dtype=torch.float32, device=q.device, + ) + attn_lse = torch.empty( + (md.reduce_partial_map.size(0) * tile_q, nhead), + dtype=torch.float32, device=q.device, + ) + final_lse = torch.empty((total_s, nhead), dtype=torch.float32, device=q.device) + output = q.new_empty((total_s, nhead, v_head_dim), dtype=self.input_dtype) + + mla_prefill_ps_asm_fwd( + q, k, v, qo_indptr, qo_indptr, + md.fp8_prefill_kv_indices, md.work_indptr, md.work_info_set, + max_q_len, layer.scaling, True, + logits, attn_lse, output, one_scale, one_scale, one_scale, + ) + mla_reduce_v1( + logits, attn_lse, md.reduce_indptr, md.reduce_final_map, + md.reduce_partial_map, tile_q, output, final_lse, + ) + return output + + def _extend_mla_decompress_prefix( + self, q, layer, forward_batch, K_Buffer, + kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, + max_q_len, max_kv_len, kv_indptr, kv_indices, qo_indptr, + ): + """Has prefix, absorbed weights differ: decompress via kv_b_proj 
+ flash_attn.""" + K_Buffer = torch.index_select(K_Buffer, 0, kv_indices) + kvc, k_pe = torch.split(K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1) + + if self.kv_cache_dtype == dtypes.fp8: + dtype = q.dtype + kvc = kvc.to(dtype) + k_pe = k_pe.to(dtype) + + kvprefix = layer.kv_b_proj(kvc.contiguous())[0] + kvprefix = kvprefix.view( + -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim + ) + k_prefix, v_prefix = torch.split( + kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 + ) + k_prefix = torch.cat( + [ + k_prefix, + torch.broadcast_to( + k_pe, (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]), + ), + ], + dim=-1, + ) + + assert forward_batch.extend_prefix_lens.shape == forward_batch.extend_seq_lens.shape + + return flash_attn_varlen_func( + q, k_prefix, v_prefix, qo_indptr, kv_indptr, + max_q_len, max_kv_len, softmax_scale=layer.scaling, causal=True, + ) + + def _extend_mla_absorbed_prefix( + self, q, layer, K_Buffer, kv_indptr, kv_indices, qo_indptr, + ): + """Has prefix, qk_head_dim == kv_lora_rank + qk_rope_head_dim: mla_prefill_fwd.""" + k_selected = torch.index_select(K_Buffer, 0, kv_indices) + if k_selected.dtype != q.dtype: + k_selected = k_selected.to(q.dtype) + compact_kv_indices = torch.arange( + k_selected.shape[0], device=q.device, dtype=torch.int32 + ) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + mla_prefill_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_selected.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + qo_indptr, kv_indptr, compact_kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_q_len, + layer.scaling, layer.logit_cap, + ) + return o + + def _call_mla_decode_fwd(self, q, k_buffer, o, layer): + """Common mla_decode_fwd invocation shared across decode/extend paths.""" + md = self.forward_metadata + mla_decode_fwd( + q, k_buffer.view(-1, 1, 1, layer.qk_head_dim), o, + md.qo_indptr, md.kv_indptr, md.kv_indices, + md.kv_last_page_len, md.max_q_len, + sm_scale=layer.scaling, logit_cap=layer.logit_cap, + work_meta_data=md.work_metadata, + work_indptr=md.work_indptr, + work_info_set=md.work_info_set, + reduce_indptr=md.reduce_indptr, + reduce_final_map=md.reduce_final_map, + reduce_partial_map=md.reduce_partial_map, + q_scale=layer.k_scale, kv_scale=layer.k_scale, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + num_kv_splits=md.num_kv_splits, + ) + + def _forward_extend_mla_speculative(self, q, layer, K_Buffer, qo_indptr): + """MLA speculative path (target_verify / draft_extend).""" + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), + dtype=self.input_dtype, + ) + self._call_mla_decode_fwd(q, K_Buffer, o, layer) + return o + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + batch_size = q.shape[0] + head_dim_out = layer.v_head_dim if layer.qk_head_dim != layer.v_head_dim else layer.head_dim + + if self.use_mla: + o = q.new_empty( + (batch_size, layer.tp_q_head_num * head_dim_out), dtype=self.input_dtype, + ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + self._call_mla_decode_fwd( + q.view(-1, 
layer.tp_q_head_num, layer.qk_head_dim), + k_buffer, + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + layer, + ) + return o + + # Non-MLA decode paths + o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out)) + + if save_kv_cache: + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + self.set_kv_buffer_with_layout_shuffle( + forward_batch.out_cache_loc, k, v, + k_buffer, v_buffer, layer.k_scale, layer.v_scale, self.page_size, + ) + + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + block_size = self.page_size + num_slots, num_kv_heads, head_size = k_buffer.shape + num_blocks = num_slots // block_size + k_buffer = k_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + v_buffer = v_buffer[: num_blocks * block_size].view( + num_blocks, block_size, num_kv_heads, head_size + ) + x = 16 // k_buffer.element_size() + new_key_cache = k_buffer.view(num_blocks, num_kv_heads, head_size // x, block_size, x) + new_value_cache = v_buffer.view(num_blocks, num_kv_heads, block_size // x, head_size, x) + + if self.decode_using_pa_ps: + total_tokens = num_blocks * block_size + q_3d = q.view(batch_size, layer.tp_q_head_num, layer.head_dim) + pa_persistent_fwd( + Q=q_3d, K=new_key_cache, V=new_value_cache, output=o, + max_qlen=self.forward_metadata.pa_metadata_max_qlen, + qo_indptr=self.forward_metadata.pa_metadata_qo_indptr, + kv_indptr=self.forward_metadata.pa_metadata_pages_kv_indptr, + kv_indices=self.forward_metadata.pa_metadata_kv_indices, + context_lens=self.forward_metadata.pa_metadata_context_lens, + work_indptr=self.pa_metadata_buffers["work_indptr"], + work_info=self.pa_metadata_buffers["work_info"], + reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], + K_QScale=self.k_qscale[:, :total_tokens], + V_QScale=self.v_qscale[:, :total_tokens], + softmax_scale=layer.scaling, mask=1, + ) + else: + q_3d = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + pa_fwd_asm( + Q=q_3d, K=new_key_cache, V=new_value_cache, + block_tables=self.forward_metadata.page_table, + context_lens=self.forward_metadata.kv_lens, + block_tables_stride0=self.forward_metadata.page_table.stride(0), + K_QScale=self.k_scale, V_QScale=self.v_scale, out_=o, + ) + + return o.view(-1, layer.tp_q_head_num * head_dim_out) diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 6f2fac5e0..07b7e10fb 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -88,6 +88,7 @@ "ATOM_DUAL_STREAM_MOE_TOKEN_THRESHOLD": lambda: int( os.getenv("ATOM_DUAL_STREAM_MOE_TOKEN_THRESHOLD", "1024") ), + "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("AITER_ROPE_FUSED_QKNORM", "0") == "1", }
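
For reference, the page_table consumed by pa_fwd_asm and pa_persistent_fwd above is built by sampling req_to_token at one token slot per page (via strided_indices) and dividing by page_size, turning the token-level mapping into block indices. A minimal standalone sketch with toy sizes (the backend uses page_size 1024; 4 here for readability, and the extra +1 safety page from the real code is omitted):

    import torch

    page_size = 4                                     # illustrative; the backend uses 1024
    max_context_len = 16
    strided_indices = torch.arange(0, max_context_len, page_size)        # [0, 4, 8, 12]

    # One request whose 10 tokens occupy blocks 7, 2 and 5 of the KV pool.
    req_to_token = torch.zeros(1, max_context_len, dtype=torch.int32)
    req_to_token[0, :10] = torch.tensor(
        [28, 29, 30, 31, 8, 9, 10, 11, 20, 21], dtype=torch.int32
    )
    req_pool_indices = torch.tensor([0])

    max_seq_pages = (10 + page_size - 1) // page_size                    # 3 pages in use
    page_table = req_to_token[
        req_pool_indices[:, None], strided_indices[:max_seq_pages][None, :]
    ] // page_size
    print(page_table)                                 # tensor([[7, 2, 5]], dtype=torch.int32)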
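
The page-level indptr that _build_pa_metadata_for_decode and _build_pa_metadata_for_prefill feed to the kernels is a ceiling division of the per-sequence context lengths followed by a cumulative sum; a small numeric sketch of that step:

    import torch

    page_size = 1024
    context_lens = torch.tensor([1, 1500, 4096], dtype=torch.int32)
    num_blocks_per_seq = (context_lens + page_size - 1) // page_size     # [1, 2, 4]
    pages_kv_indptr = torch.zeros(len(context_lens) + 1, dtype=torch.int32)
    pages_kv_indptr[1:] = torch.cumsum(num_blocks_per_seq, dim=0)
    print(pages_kv_indptr)                            # tensor([0, 1, 3, 7], dtype=torch.int32)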
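
create_flashinfer_kv_indices_triton, used throughout the metadata builders, fills kv_indices by copying each request's occupied slots from its req_to_token row into one flat array at the offsets given by kv_indptr. A plain-PyTorch reference of that behaviour, as an assumption inferred from how the kernel is called here (the Triton version exists precisely to avoid this host-side loop and its synchronisation):

    import torch

    def gather_kv_indices(req_to_token, req_pool_indices, seq_lens, kv_indptr):
        # Reference only: copy each request's first seq_len token slots from its
        # req_to_token row into a flat int32 kv_indices array, packed by kv_indptr.
        kv_indices = torch.empty(
            int(kv_indptr[-1].item()), dtype=torch.int32, device=req_to_token.device
        )
        for i, req in enumerate(req_pool_indices.tolist()):
            start, n = int(kv_indptr[i].item()), int(seq_lens[i].item())
            kv_indices[start : start + n] = req_to_token[req, :n]
        return kv_indices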
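
Finally, the blocked key/value views built in forward_decode are pure reinterpretations of the flat KV pool: x = 16 // element_size groups the innermost data into 16-byte packets, which, judging by the buffer shapes passed around, is the layout reshape_and_cache_shuffle_triton writes and the asm kernels read. A shape-only sketch with toy sizes (bf16, so x == 8):

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 1024, 8, 128
    k_buffer = torch.zeros(
        num_blocks * block_size, num_kv_heads, head_size, dtype=torch.bfloat16
    )
    v_buffer = torch.zeros_like(k_buffer)

    x = 16 // k_buffer.element_size()                 # 8 for bf16/fp16, 16 for fp8
    new_key_cache = k_buffer.view(num_blocks, num_kv_heads, head_size // x, block_size, x)
    new_value_cache = v_buffer.view(num_blocks, num_kv_heads, block_size // x, head_size, x)
    print(new_key_cache.shape)                        # torch.Size([4, 8, 16, 1024, 8])
    print(new_value_cache.shape)                      # torch.Size([4, 8, 128, 128, 8])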