diff --git a/atom/config.py b/atom/config.py
index e273fb73c..bb356ca39 100644
--- a/atom/config.py
+++ b/atom/config.py
@@ -621,22 +621,29 @@ def __post_init__(self):
         # assert os.path.isdir(self.model)
         assert 1 <= self.tensor_parallel_size <= 8
-        self.hf_config = get_hf_config(self.model)
+        if is_plugin_mode():
+            # plugin mode
+            assert (
+                self.plugin_config is not None
+            ), "plugin_config is required in plugin mode"
+            self.hf_config = self.plugin_config.model_config.hf_config
+        else:
+            self.hf_config = get_hf_config(self.model)
+
+        self.generation_config = get_generation_config(self.model)
+        if self.generation_config is not None:
+            if (
+                eos_ids := getattr(self.generation_config, "eos_token_id", None)
+            ) is not None:
+                self.stop_token_ids = (
+                    [eos_ids] if isinstance(eos_ids, int) else eos_ids
+                )
         if not hasattr(self.hf_config, "rope_parameters"):
             # Compatibility path for transformers < 5, which has no `rope_parameters`
-            rope_params = getattr(self.hf_config, "rope_scaling", {})
-            if rope_params is None:
-                rope_params = {}
-            rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None)
-            rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default")
+            rope_params = getattr(self.hf_config, "rope_scaling", {}) or {}
+            rope_params["rope_theta"] = self.hf_config.rope_theta
+            rope_params["rope_type"] = rope_params.get("rope_type", "default")
             self.hf_config.rope_parameters = rope_params
-
-        self.generation_config = get_generation_config(self.model)
-        if self.generation_config is not None:
-            if (
-                eos_ids := getattr(self.generation_config, "eos_token_id", None)
-            ) is not None:
-                self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids
         self.quant_config = get_quant_config(self.hf_config)
         hf_config_max_position_embeddings = getattr(
             self.hf_config, "max_position_embeddings", 8192
diff --git a/atom/model_ops/radix_attention.py b/atom/model_ops/radix_attention.py
index 25388b384..34fdf0f90 100644
--- a/atom/model_ops/radix_attention.py
+++ b/atom/model_ops/radix_attention.py
@@ -9,6 +9,7 @@
 from .base_attention import BaseAttention
 from atom.plugin.prepare import is_plugin_mode, is_sglang
 from atom.models.utils import maybe_prefix
+from atom.utils import envs
 
 
 class RadixAttention(BaseAttention):
@@ -47,7 +48,6 @@ def __init__(
             prefix=prefix,
             **kwargs,
         )
-        self.rotary_emb = rotary_emb
 
         if is_sglang():
             from sglang.srt.layers.radix_attention import RadixAttention
@@ -64,6 +64,8 @@ def __init__(
             raise NotImplementedError(
                 "RadixAttention is only supported for plugin mode for sglang for now"
             )
+        # If enabled, the KV-cache write is fused into the RoPE kernel
+        self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM
 
     def forward_impl_plugin_mode(
         self,
            # for sglang, forward_batch is required
            forward_batch = kwargs.get("forward_batch", None)
            assert forward_batch is not None, "forward_batch is required for sglang"
-           if self.rotary_emb is not None:
-               assert positions is not None, "positions is required for ROPE"
-               query, key = self.rotary_emb(positions, query, key)
-           return self.attn(q=query, k=key, v=value, forward_batch=forward_batch)
+           # forward_batch carries the attn_backend field, which resolves to the
+           # backend registered by ATOM
+           return self.attn(
+               query,
+               key,
+               value,
+               forward_batch=forward_batch,
+               save_kv_cache=not self.use_aiter_rope_fused_qknorm,
+           )
        else:
            raise NotImplementedError(
                "RadixAttention is only supported for plugin mode for sglang for now"
            )
diff --git a/atom/models/qwen3_moe.py b/atom/models/qwen3_moe.py
index 9a0e1eba1..54421f008 100644
--- a/atom/models/qwen3_moe.py
+++ b/atom/models/qwen3_moe.py
@@ -5,7 +5,7 @@
 from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size
 
 # from atom.model_ops.rotary_embedding import get_rope
-from aiter.rotary_embedding import get_rope
+from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg
 
 from atom.config import Config, QuantizationConfig
 from atom.model_ops.activation import SiluAndMul
@@ -34,6 +34,7 @@
 from atom.utils import envs
 from torch import nn
 from atom.model_loader.loader import load_model_in_plugin_mode
+from atom.plugin.prepare import is_sglang
 
 # import torch.distributed as dist
 from transformers import PretrainedConfig
@@ -42,6 +43,7 @@
 ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = (
     envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION
 )
+ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM
 
 
 class Qwen3MoeMLP(nn.Module):
@@ -228,6 +230,65 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.layer_num = layer_num
 
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32)
+        self.v_scale = torch.tensor([1.0], dtype=torch.float32)
+
+    def forward_sgl_plugin_mode(
+        self,
+        positions: torch.Tensor,
+        qkv: torch.Tensor,
+        **model_kwargs: dict[str, Any] | None,
+    ):
+        if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE:
+            forward_batch = model_kwargs.get("forward_batch", None)
+            assert forward_batch is not None, "forward_batch is required for sglang"
+            k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
+                self.layer_num
+            )
+            block_size = 1024  # Default fallback
+            if hasattr(forward_batch, "attn_backend") and hasattr(
+                forward_batch.attn_backend, "page_size"
+            ):
+                block_size = forward_batch.attn_backend.page_size
+            elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr(
+                forward_batch.token_to_kv_pool.allocator, "page_size"
+            ):
+                block_size = forward_batch.token_to_kv_pool.allocator.page_size
+            elif hasattr(forward_batch.token_to_kv_pool, "page_size"):
+                block_size = forward_batch.token_to_kv_pool.page_size
+            x = 16 // k_buffer.element_size()
+            aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg(
+                kv_cache=(k_buffer, v_buffer),
+                cache_loc=forward_batch.out_cache_loc,
+                k_scale=self.k_scale,
+                v_scale=self.v_scale,
+                return_kv=True,
+                use_shuffle_layout=True,
+                block_size=block_size,
+                x=x,
+            )
+            q, k, v = self.rotary_emb(
+                qkv,
+                self.q_norm.weight,
+                self.k_norm.weight,
+                positions,
+                self.num_heads,
+                self.num_kv_heads,
+                self.q_norm.eps,
+                fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg,
+            )
+        else:
+            q, k, v = torch.split(
+                qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1
+            )
+            # Add qk-norm
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
+            q, k = self.rotary_emb(positions, q, k)
+
+        attn_output = self.attn(q, k, v, positions=positions, **model_kwargs)
+        return attn_output
 
     def forward(
         self,
@@ -245,13 +306,18 @@
                 query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv
             )
         else:
-            # Add qk-norm
-            q = self.q_norm(q)
-            k = self.k_norm(k)
+            if is_sglang():
+                attn_output = self.forward_sgl_plugin_mode(
+                    positions, qkv, **model_kwargs
+                )
+            else:
+                # Add qk-norm
+                q = self.q_norm(q)
+                k = self.k_norm(k)
 
-            attn_output = self.attn(
-                query=q, key=k, value=v, positions=positions, **model_kwargs
-            )
+                attn_output = self.attn(
+                    query=q, key=k, value=v, positions=positions, **model_kwargs
+                )
 
         output = self.o_proj(attn_output)
         return output
@@ -265,7 +331,7 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> None:
         self.hidden_size = config.hidden_size
         rope_params = config.rope_parameters
         rope_theta = rope_params["rope_theta"]
-        rope_scaling = rope_params
+        rope_scaling = None if rope_params["rope_type"] == "default" else rope_params
         kv_cache_dtype = atom_config.kv_cache_dtype
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         # DecoderLayers are created with `make_layers` which passes the prefix
diff --git a/atom/plugin/attention_backend/__init__.py b/atom/plugin/attention_backend/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/atom/plugin/attention_backend/sgl_attn_backend.py b/atom/plugin/attention_backend/sgl_attn_backend.py
new file mode 100644
index 000000000..87ed05af9
--- /dev/null
+++ b/atom/plugin/attention_backend/sgl_attn_backend.py
@@ -0,0 +1,1116 @@
+from __future__ import annotations
+
+"""
+End-to-end attention backend for the SGLang plugin built on aiter kernels.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInput
+
+try:
+    from aiter import (
+        flash_attn_varlen_func,
+        dtypes,
+        get_pa_metadata_info_v1,
+        get_pa_metadata_v1,
+        pa_fwd_asm,
+        pa_persistent_fwd,
+    )
+except ImportError:
+    print(
+        "aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device."
+    )
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def reshape_and_cache_shuffle_kernel(
+    key_ptr,  # [num_tokens, num_kv_heads, head_size]
+    value_ptr,  # [num_tokens, num_kv_heads, head_size]
+    key_cache_ptr,  # [num_blocks, num_kv_heads, head_size // x, block_size, x]
+    value_cache_ptr,  # [num_blocks, num_kv_heads, block_size // x, head_size, x]
+    slot_mapping_ptr,  # [num_tokens]
+    k_scale_ptr,
+    v_scale_ptr,
+    x,
+    k_stride0,
+    v_stride0,
+    block_size,
+    head_size,
+    num_kv_heads,
+    BLOCK_SIZE: tl.constexpr,
+    QUANT: tl.constexpr,
+):
+    tid = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    src_offset_k = tid * k_stride0 + head_id * head_size
+    src_offset_v = tid * v_stride0 + head_id * head_size
+    slot_id = tl.load(slot_mapping_ptr + tid)
+    if slot_id < 0:
+        return
+    block_id = slot_id // block_size
+    block_offset = slot_id % block_size
+    dst_offset = (
+        block_id * num_kv_heads * head_size * block_size
+        + head_id * head_size * block_size
+    )
+    # K tile layout within a (block, head): (head_size // x, block_size, x)
+    dst_k_shuffle_offset = (
+        dst_offset + offset // x * block_size * x + block_offset * x + offset % x
+    )
+    # V tile layout within a (block, head): (block_size // x, head_size, x)
+    dst_v_shuffle_offset = (
+        dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x
+    )
+    k_val = tl.load(key_ptr + src_offset_k + offset)
+    v_val = tl.load(value_ptr + src_offset_v + offset)
+    if QUANT:
+        k_scale = tl.load(k_scale_ptr)
+        v_scale = tl.load(v_scale_ptr)
+        k_dtype = key_cache_ptr.type.element_ty
+        v_dtype = value_cache_ptr.type.element_ty
+        k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype)
+        v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype)
+    tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val)
+    tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val)
+
+
+def reshape_and_cache_shuffle_triton(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scales: torch.Tensor,
+    v_scales: torch.Tensor,
+):
+    num_tokens = slot_mapping.shape[0]
+    _, num_kv_heads, head_size = key.shape
+    num_blocks, block_size, _, _ = key_cache.shape
+    x = 16 // key_cache.element_size()
+    k_cache_template = torch.empty(
+        [num_blocks, num_kv_heads, head_size // x, block_size, x],
+        dtype=key_cache.dtype,
+        device="meta",
+    )
+    v_cache_template = torch.empty(
+        [num_blocks, num_kv_heads, block_size // x, head_size, x],
+        dtype=value_cache.dtype,
+        device="meta",
+    )
+    new_key_cache = key_cache.view_as(k_cache_template)
+    new_value_cache = value_cache.view_as(v_cache_template)
+    QUANT = False
+    if kv_cache_dtype.startswith("fp8"):
+        QUANT = True
+    grid = (
+        num_tokens,
+        num_kv_heads,
+    )
+    reshape_and_cache_shuffle_kernel[grid](
+        key,
+        value,
+        new_key_cache,
+        new_value_cache,
+        slot_mapping,
+        k_scales,
+        v_scales,
+        x,
+        key.stride(0),
+        value.stride(0),
+        block_size,
+        head_size,
+        num_kv_heads,
+        BLOCK_SIZE=head_size,
+        QUANT=QUANT,
+    )
+
+
+@dataclass
+class ForwardMetadata:
+    # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode
+    kv_indptr: Optional[torch.Tensor]
+    kv_indices: Optional[torch.Tensor]
+    qo_indptr: Optional[torch.Tensor]
+    kv_last_page_len: Optional[torch.Tensor]
+    max_q_len: Optional[int]
+    max_kv_len: Optional[int]
+    page_table: Optional[torch.Tensor]
+    kv_lens: Optional[torch.Tensor]
+    # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA)
+    pa_metadata_qo_indptr: Optional[torch.Tensor] = None
+    pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None
+    pa_metadata_kv_indices: Optional[torch.Tensor] = None
+    pa_metadata_context_lens: Optional[torch.Tensor] = None
+    pa_metadata_max_qlen: Optional[int] = None
+    pa_metadata_tp_q_head_num: Optional[int] = None
+    # Prefill metadata for mha_batch_prefill_func (only used in prefill mode, non-MLA)
+    # prefill_pages_kv_indptr: Optional[torch.Tensor] = None
+    # prefill_kv_indices: Optional[torch.Tensor] = None
+    # prefill_kv_last_page_lens: Optional[torch.Tensor] = None
+
+
+class ATOMAttnBackendForSgl(AiterAttnBackend):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf)
+
+        mapping = getattr(
+            model_runner.token_to_kv_pool, "full_attention_layer_id_mapping", None
+        )
+
+        if isinstance(mapping, dict) and mapping:
+            first_full_attn_id = next(iter(mapping.keys()))
+        else:
+            first_full_attn_id = 0
+
+        self.q_dtype = model_runner.dtype  # Save q dtype for pa_metadata building
+
+        assert (
+            not self.use_mla
+        ), "MLA mode is not implemented yet in ATOMAttnBackendForSgl."
+
+        # Pre-initialized qo_indptr for pa_persistent_fwd decode mode: [0, 1, 2, ..., max_bs]
+        # In decode mode, each sequence has 1 token, so this is always [0, 1, 2, ..., batch_size]
+        max_bs = model_runner.req_to_token_pool.size
+        self.pa_decode_qo_indptr = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+        )
+        self.seq_lens = torch.zeros(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.page_table = torch.zeros(
+            (max_bs, self.max_context_len // self.page_size),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        # Pre-compute strided indices for page_table construction (used in both CUDA Graph and non-CUDA Graph modes)
+        self.strided_indices = torch.arange(
+            0, self.max_context_len, self.page_size, device=model_runner.device
+        )
+
+        if not self.use_mla:
+            # Pre-allocate buffers for pa_persistent_fwd (used in both CUDA graph and non-CUDA graph modes)
+            max_num_blocks_per_seq = (
+                self.max_context_len + self.page_size - 1
+            ) // self.page_size
+            max_total_blocks = max_bs * max_num_blocks_per_seq
+            self.pa_kv_indices = torch.zeros(
+                max_total_blocks, dtype=torch.int32, device=self.device
+            )
+            # Pre-allocate pa_kv_indptr buffer (similar to self.kv_indptr, but dedicated for pa_persistent_fwd)
+            self.pa_kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=self.device
+            )
+            # Pre-initialized batch indices [0, 1, 2, ..., max_bs-1] for Triton kernel
+            self.pa_batch_indices = torch.arange(
+                0, max_bs, dtype=torch.int32, device=self.device
+            )
+
+        # Pre-allocated descale tensors for FP8 attention (q, k, v all use scale=1.0);
+        # referenced as self.k_scale / self.v_scale by forward_decode_pa
+        self.k_scale = torch.ones(1, dtype=torch.float32, device=self.device)
+        self.v_scale = torch.ones(1, dtype=torch.float32, device=self.device)
+
+        self.logits_soft_cap = 0.0
+
+        self.forward_metadata: Optional[ForwardMetadata] = None
+
+        self.pa_metadata_buffers = None
+
+        k_buffer, _ = model_runner.token_to_kv_pool.get_kv_buffer(first_full_attn_id)
+        num_slots, num_kv_heads, _ = k_buffer.shape
+        block_size = self.page_size
+        num_blocks = num_slots // block_size
+        max_total_tokens = num_blocks * block_size
+        self.k_qscale = torch.ones(
+            num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device
+        )
+        self.v_qscale = torch.ones(
+            num_kv_heads, max_total_tokens, dtype=torch.float32, device=self.device
+        )
+        self.decode_using_pa_ps = self.page_size == 1024
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize per-batch forward metadata for the ATOM attention backend."""
+
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info
+        # qo_indptr = None
+        # kv_last_page_len = None
+        # max_q_len = None
+        page_table = None
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            if self.use_mla:
+                raise NotImplementedError(
+                    "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl."
+                )
+            else:
+                if self.decode_using_pa_ps:
+                    # Non-MLA decode mode: use same logic as CUDA Graph mode for page_table construction
+                    seq_lens_cpu = forward_batch.seq_lens_cpu
+                    if seq_lens_cpu is None:
+                        seq_lens_cpu = forward_batch.seq_lens.cpu()
+
+                    # Common setup consistent with CUDA Graph mode (init_forward_metadata_replay_cuda_graph)
+                    page_table_persistent = self.page_table
+                    seq_lens_persistent = self.seq_lens
+                    seq_lens_persistent.fill_(0)
+                    page_table_persistent.fill_(0)
+                    seq_lens_persistent[:bs].copy_(
+                        forward_batch.seq_lens, non_blocking=True
+                    )
+                    max_seq_pages = (
+                        seq_lens_cpu.max().item() + self.page_size - 1
+                    ) // self.page_size + 1
+                    page_table = self.req_to_token[
+                        forward_batch.req_pool_indices[:, None],
+                        self.strided_indices[:max_seq_pages][None, :],
+                    ]
+                    page_table_persistent[:bs, :max_seq_pages].copy_(
+                        page_table // self.page_size, non_blocking=True
+                    )
+                else:
+                    page_table = forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, :
+                    ]
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    None,  # qo_indptr not used in non-MLA mode
+                    None,  # kv_last_page_len not used in non-MLA mode
+                    1,  # max_q_len = 1 for decode mode
+                    None,
+                    (
+                        page_table_persistent[:bs, :max_seq_pages]
+                        if self.decode_using_pa_ps
+                        else page_table
+                    ),
+                    (
+                        seq_lens_persistent[:bs]
+                        if self.decode_using_pa_ps
+                        else forward_batch.seq_lens
+                    ),
+                )
+
+                # Build pa_metadata for pa_persistent_fwd
+                if self.decode_using_pa_ps:
+                    self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head)
+                # return  # Early return for non-MLA decode mode
+        else:
+            prefix_lens = forward_batch.extend_prefix_lens
+
+            if self.use_mla:
+                raise NotImplementedError(
+                    "MLA prefill mode is not implemented yet in ATOMAttnBackendForSgl."
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                # Get page_table for mha_batch_prefill_func
+                page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, :
+                ]
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    self.qo_indptr[
+                        : bs + 1
+                    ],  # qo_indptr is set by indices_updater_prefill
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                    page_table,
+                    forward_batch.seq_lens,
+                )
+
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not self.use_mla
+            and self.forward_metadata.page_table is not None
+        ):
+            if self.page_size > 1:
+                seq_lens_cpu = forward_batch.seq_lens_cpu
+                if seq_lens_cpu is None:
+                    seq_lens_cpu = forward_batch.seq_lens.cpu()
+                max_seq_pages = (
+                    seq_lens_cpu.max().item() + self.page_size - 1
+                ) // self.page_size + 1
+                self.forward_metadata.page_table = (
+                    self.forward_metadata.page_table[
+                        :, self.strided_indices[:max_seq_pages]
+                    ]
+                    // self.page_size
+                )
+            if self.decode_using_pa_ps:
+                self._build_pa_metadata_for_prefill(forward_batch.batch_size)
+        if (
+            not self.decode_using_pa_ps
+            and self.page_size > 1
+            and self.forward_metadata.page_table is not None
+        ):
+            self.forward_metadata.page_table = (
+                self.forward_metadata.page_table[:, self.strided_indices]
+                // self.page_size
+            )
+
+    def _allocate_pa_metadata_buffers(
+        self,
+        work_metadata_ptrs_size,
+        work_metadata_ptrs_type,
+        work_indptr_size,
+        work_indptr_type,
+        work_info_size,
+        work_info_type,
+        reduce_indptr_size,
+        reduce_indptr_type,
+        reduce_final_map_size,
+        reduce_final_map_type,
+        reduce_partial_map_size,
+        reduce_partial_map_type,
+    ):
+        """Allocate or reuse pa_metadata buffers."""
+        if self.pa_metadata_buffers is None:
+            self.pa_metadata_buffers = {}
+
+        def _get_size_val(size):
+            return size[0] if isinstance(size, tuple) else size
+
+        # Allocate work_metadata_ptrs
+        size_val = _get_size_val(work_metadata_ptrs_size)
+        if (
+            "work_metadata_ptrs" not in self.pa_metadata_buffers
+            or self.pa_metadata_buffers["work_metadata_ptrs"].shape[0] < size_val
+        ):
+            self.pa_metadata_buffers["work_metadata_ptrs"] = torch.empty(
+                work_metadata_ptrs_size,
+                dtype=work_metadata_ptrs_type,
+                device=self.device,
+            )
+
+        # Allocate work_indptr
+        size_val = _get_size_val(work_indptr_size)
+        if (
+            "work_indptr" not in self.pa_metadata_buffers
+            or self.pa_metadata_buffers["work_indptr"].shape[0] < size_val
+        ):
+            self.pa_metadata_buffers["work_indptr"] = torch.zeros(
+                work_indptr_size, dtype=work_indptr_type, device=self.device
+            )
+        else:
+            self.pa_metadata_buffers["work_indptr"].zero_()
+
+        # Allocate work_info
+        size_val = _get_size_val(work_info_size)
+        if (
+            "work_info" not in self.pa_metadata_buffers
+            or len(self.pa_metadata_buffers["work_info"].shape) < len(work_info_size)
+            or self.pa_metadata_buffers["work_info"].shape[0] < size_val
+        ):
+            self.pa_metadata_buffers["work_info"] = torch.zeros(
+                work_info_size, dtype=work_info_type, device=self.device
+            )
+        else:
+            self.pa_metadata_buffers["work_info"].zero_()
+
+        # Allocate reduce_indptr
+        size_val = _get_size_val(reduce_indptr_size)
+        if (
+            "reduce_indptr" not in self.pa_metadata_buffers
+            or self.pa_metadata_buffers["reduce_indptr"].shape[0] < size_val
+        ):
+            self.pa_metadata_buffers["reduce_indptr"] = torch.zeros(
+
+                reduce_indptr_size, dtype=reduce_indptr_type, device=self.device
+            )
+        else:
+            self.pa_metadata_buffers["reduce_indptr"].zero_()
+
+        # Allocate reduce_final_map
+        size_val = _get_size_val(reduce_final_map_size)
+        if (
+            "reduce_final_map" not in self.pa_metadata_buffers
+            or len(self.pa_metadata_buffers["reduce_final_map"].shape)
+            < len(reduce_final_map_size)
+            or self.pa_metadata_buffers["reduce_final_map"].shape[0] < size_val
+        ):
+            self.pa_metadata_buffers["reduce_final_map"] = torch.zeros(
+                reduce_final_map_size, dtype=reduce_final_map_type, device=self.device
+            )
+        else:
+            self.pa_metadata_buffers["reduce_final_map"].zero_()
+
+        # Allocate reduce_partial_map
+        reduce_partial_map_size_val = (
+            reduce_partial_map_size
+            if isinstance(reduce_partial_map_size, int)
+            else reduce_partial_map_size[0]
+        )
+        if (
+            "reduce_partial_map" not in self.pa_metadata_buffers
+            or self.pa_metadata_buffers["reduce_partial_map"].shape[0]
+            < reduce_partial_map_size_val
+        ):
+            self.pa_metadata_buffers["reduce_partial_map"] = torch.zeros(
+                reduce_partial_map_size,
+                dtype=reduce_partial_map_type,
+                device=self.device,
+            )
+        else:
+            self.pa_metadata_buffers["reduce_partial_map"].zero_()
+
+    def _build_pa_metadata_for_decode(
+        self,
+        batch_size: int,
+        tp_q_head_num: Optional[int] = None,
+    ):
+        """Build pa_metadata buffers for pa_persistent_fwd in decode mode.
+
+        This method prepares all metadata buffers needed for the pa_persistent_fwd kernel.
+        The metadata can be reused across multiple layers in the same forward pass.
+
+        Args:
+            batch_size: Batch size for the current forward pass
+            tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head.
+        """
+        max_qlen = 1
+
+        # Use provided tp_q_head_num or default to self.num_head
+        if tp_q_head_num is None:
+            tp_q_head_num = self.num_head
+
+        # kv_dtype_for_metadata = dtypes.fp8
+        (
+            (work_metadata_ptrs_size, work_metadata_ptrs_type),
+            (work_indptr_size, work_indptr_type),
+            (work_info_size, work_info_type),
+            (reduce_indptr_size, reduce_indptr_type),
+            (reduce_final_map_size, reduce_final_map_type),
+            (reduce_partial_map_size, reduce_partial_map_type),
+        ) = get_pa_metadata_info_v1(
+            batch_size,
+            self.num_kv_head,
+        )
+        # Allocate metadata buffers with reuse optimization for multi-layer forward passes
+        self._allocate_pa_metadata_buffers(
+            work_metadata_ptrs_size,
+            work_metadata_ptrs_type,
+            work_indptr_size,
+            work_indptr_type,
+            work_info_size,
+            work_info_type,
+            reduce_indptr_size,
+            reduce_indptr_type,
+            reduce_final_map_size,
+            reduce_final_map_type,
+            reduce_partial_map_size,
+            reduce_partial_map_type,
+        )
+        qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1]
+
+        # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode)
+        # Note: kv_lens comes from self.seq_lens which is already int32
+        context_lens = self.forward_metadata.kv_lens
+
+        kernel_block_size = self.page_size
+        num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size
+        # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd)
+        pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1]
+        pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0)
+
+        # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync
+        # page_table shape: [batch_size, max_num_blocks_per_seq]
+        # Note: page_table comes from self.page_table which is already int32 and always set before this call
+        page_table = self.forward_metadata.page_table
+
+        # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync)
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            page_table,
+            self.pa_batch_indices[:batch_size],  # [0, 1, 2, ..., batch_size-1]
+            num_blocks_per_seq,
+            pages_kv_indptr,
+            None,  # kv_start_idx
+            self.pa_kv_indices,
+            page_table.stride(0),
+        )
+        # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr
+        kv_indices = self.pa_kv_indices
+
+        get_pa_metadata_v1(
+            seqlens_qo_indptr=qo_indptr,
+            pages_kv_indptr=pages_kv_indptr,
+            context_lens=context_lens.int(),
+            num_heads_per_head_k=tp_q_head_num // self.num_kv_head,
+            num_heads_k=self.num_kv_head,
+            is_causal=True,
+            work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"],
+            work_indptr=self.pa_metadata_buffers["work_indptr"],
+            work_info=self.pa_metadata_buffers["work_info"],
+            reduce_indptr=self.pa_metadata_buffers["reduce_indptr"],
+            reduce_final_map=self.pa_metadata_buffers["reduce_final_map"],
+            reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"],
+            kv_granularity=max(kernel_block_size, 16),
+            block_size=kernel_block_size,
+            max_seqlen_qo=max_qlen,
+            uni_seqlen_qo=max_qlen,
+            fast_mode=True,
+            topk=-1,
+            max_split_per_batch=-1,
+        )
+        # Store computed values in ForwardMetadata for reuse in forward_decode
+        self.forward_metadata.pa_metadata_qo_indptr = qo_indptr
+        self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr
+        self.forward_metadata.pa_metadata_kv_indices = kv_indices
+        self.forward_metadata.pa_metadata_context_lens = context_lens
+        self.forward_metadata.pa_metadata_max_qlen = max_qlen
+        self.forward_metadata.pa_metadata_tp_q_head_num = tp_q_head_num
+
+    def _build_pa_metadata_for_prefill(self, batch_size: int):
+        """Build metadata for mha_batch_prefill_func in prefill mode.
+
+        This method prepares page-level metadata needed for mha_batch_prefill_func.
+        The metadata is computed once per forward pass and reused across all layers.
+ """ + block_size = self.page_size + context_lens = self.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + # Page-level kv_indptr (reuse pa_kv_indptr buffer) + pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + # Build kv_indices from page_table using triton kernel + page_table = self.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + self.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, # kv_start_idx + self.pa_kv_indices, + page_table.stride(0), + ) + # kv_indices = self.pa_kv_indices + + # Compute kv_last_page_lens for each sequence + # kv_last_page_lens = ((context_lens - 1) % block_size + 1).int() + + # Store in ForwardMetadata for reuse in forward_extend + # self.forward_metadata.prefill_pages_kv_indptr = pages_kv_indptr + # self.forward_metadata.prefill_kv_indices = kv_indices + # self.forward_metadata.prefill_kv_last_page_lens = kv_last_page_lens + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + # Always use preshuffle layout for pa_fwd_asm + self.page_table = torch.zeros( + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=self.device, + ) + self.seq_lens = torch.zeros((max_bs,), dtype=torch.int32, device=self.device) + self.strided_indices = torch.arange( + 0, self.max_context_len, self.page_size, device=self.device + ) + + # Pre-allocate buffers for pa_metadata in CUDA graph mode (non-MLA decode) + if self.decode_using_pa_ps and not self.use_mla: + # Pre-allocate pa_metadata buffers for CUDA graph compatibility + # These buffers will be reused in capture and replay phases + # Use max_bs and max_qlen=1 (decode mode) to calculate buffer sizes + # max_qlen = 1 # decode mode + # kv_dtype_for_metadata = dtypes.fp8 + ( + (work_metadata_ptrs_size, work_metadata_ptrs_type), + (work_indptr_size, work_indptr_type), + (work_info_size, work_info_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_pa_metadata_info_v1( + max_bs, + self.num_kv_head, + ) + + # Pre-allocate buffers with maximum size for CUDA graph compatibility + self._allocate_pa_metadata_buffers( + work_metadata_ptrs_size, + work_metadata_ptrs_type, + work_indptr_size, + work_indptr_type, + work_info_size, + work_info_type, + reduce_indptr_size, + reduce_indptr_type, + reduce_final_map_size, + reduce_final_map_type, + reduce_partial_map_size, + reduce_partial_map_type, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + if forward_mode.is_decode_or_idle(): + if self.use_mla: + # MLA mode: kv_indptr and kv_indices are used in forward_decode + raise NotImplementedError( + "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl." 
+                )
+            else:
+                # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode
+                # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead)
+                page_table = self.page_table[:bs, :]
+                self.seq_lens[:bs].copy_(seq_lens, non_blocking=True)
+                seq_lens_persistent = self.seq_lens[:bs]
+                self.forward_metadata = ForwardMetadata(
+                    None,  # kv_indptr not used in non-MLA decode mode
+                    None,  # kv_indices not used in non-MLA decode mode
+                    None,  # qo_indptr will be set by _build_pa_metadata_for_decode
+                    None,  # kv_last_page_len not used in non-MLA mode
+                    1,  # max_q_len = 1 for decode mode
+                    None,  # max_kv_len
+                    page_table,
+                    seq_lens_persistent,
+                )
+
+                # Build pa_metadata using CUDA graph buffers
+                if self.decode_using_pa_ps:
+                    self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head)
+                return  # Early return for non-MLA decode mode
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInput],
+        seq_lens_cpu: Optional[torch.Tensor],
+        out_cache_loc: Optional[torch.Tensor] = None,
+    ):
+        if forward_mode.is_decode_or_idle():
+            # Common setup for both MLA and non-MLA modes
+            page_table_persistent = self.page_table
+            seq_lens_persistent = self.seq_lens
+            seq_lens_persistent.fill_(0)
+            page_table_persistent.fill_(0)
+            seq_lens_persistent[:bs].copy_(seq_lens, non_blocking=True)
+            max_seq_pages = (
+                seq_lens_cpu.max().item() + self.page_size - 1
+            ) // self.page_size + 1
+            page_table = self.req_to_token[
+                req_pool_indices[:, None],
+                self.strided_indices[:max_seq_pages][None, :],
+            ]
+            page_table_persistent[:bs, :max_seq_pages].copy_(
+                page_table // self.page_size, non_blocking=True
+            )
+
+            if self.use_mla:
+                # MLA mode: kv_indptr and kv_indices are used in forward_decode
+                raise NotImplementedError(
+                    "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl."
+                )
+            else:
+                # Non-MLA decode mode: kv_indptr and kv_indices are NOT used in forward_decode
+                # (forward_decode uses pa_metadata_pages_kv_indptr and pa_metadata_kv_indices instead)
+                self.forward_metadata = ForwardMetadata(
+                    None,  # kv_indptr not used in non-MLA decode mode
+                    None,  # kv_indices not used in non-MLA decode mode
+                    None,
+                    None,  # kv_last_page_len not used in non-MLA mode
+                    1,  # max_q_len = 1 for decode mode, non-MTP
+                    None,  # max_kv_len
+                    page_table_persistent[:bs, :max_seq_pages],
+                    seq_lens_persistent[:bs],
+                )
+
+            # Rebuild pa_metadata using CUDA graph buffers (updates content, keeps same addresses)
+            if self.decode_using_pa_ps:
+                self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head)
+        else:
+            raise ValueError("Invalid forward mode")
+
+    def set_kv_buffer_with_layout_shuffle(
+        self,
+        cache_loc,
+        k,
+        v,
+        k_buffer,
+        v_buffer,
+        k_scale,
+        v_scale,
+        block_size,
+    ):
+        num_slots, num_kv_heads, head_dim = k_buffer.shape
+        num_blocks = num_slots // block_size
+        num_slots_with_block = num_blocks * block_size
+        k_buffer = k_buffer[:num_slots_with_block].view(
+            num_blocks, block_size, num_kv_heads, head_dim
+        )
+        v_buffer = v_buffer[:num_slots_with_block].view(
+            num_blocks, block_size, num_kv_heads, head_dim
+        )
+        reshape_and_cache_shuffle_triton(
+            k,
+            v,
+            k_buffer,
+            v_buffer,
+            cache_loc,
+            "auto",
+            k_scale,
+            v_scale,
+        )
+
+    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
+        # print(f"Running forward_extend with q shape {q.shape}, k shape {k.shape}, v shape {v.shape}", flush=True)
+        # print(f"q dtype: {q.dtype}, k dtype: {k.dtype}, v dtype: {v.dtype}", flush=True)
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        self.logits_soft_cap = layer.logit_cap
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
+                    layer.layer_id
+                )
+                self.set_kv_buffer_with_layout_shuffle(
+                    cache_loc,
+                    k,
+                    v,
+                    k_buffer,
+                    v_buffer,
+                    layer.k_scale,
+                    layer.v_scale,
+                    self.page_size,
+                )
+                # forward_batch.token_to_kv_pool.set_kv_buffer(
+                #     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                # )
+
+        seqlens_in_batch = forward_batch.seq_lens
+        cu_seqlens_q = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # use fp8 mha directly
+        if q.dtype != k.dtype and k.dtype == dtypes.fp8:
+            q = q.to(dtypes.fp8)
+        o = flash_attn_varlen_func(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
+            v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_q=self.forward_metadata.max_q_len,
+            max_seqlen_k=self.forward_metadata.max_kv_len,
+            min_seqlen_q=0,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            window_size=(-1, -1, 0),
+            sink_ptr=None,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode_pa(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+            self.set_kv_buffer_with_layout_shuffle(
+                forward_batch.out_cache_loc,
+                k,
+                v,
+                k_buffer,
+                v_buffer,
+                layer.k_scale,
+                layer.v_scale,
+                self.page_size,
+            )
+
+        if self.use_mla:
+            raise NotImplementedError(
+                "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl."
+            )
+        else:
+            k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+
+            block_size = self.page_size
+            num_slots, num_kv_heads, head_size = k_buffer.shape
+            num_blocks = num_slots // block_size
+            k_buffer = k_buffer[: num_blocks * block_size].view(
+                num_blocks, block_size, num_kv_heads, head_size
+            )
+            v_buffer = v_buffer[: num_blocks * block_size].view(
+                num_blocks, block_size, num_kv_heads, head_size
+            )
+
+            x = 16 // k_buffer.element_size()
+            k_cache_template = torch.empty(
+                [num_blocks, num_kv_heads, head_size // x, block_size, x],
+                dtype=k_buffer.dtype,
+                device="meta",
+            )
+            v_cache_template = torch.empty(
+                [num_blocks, num_kv_heads, block_size // x, head_size, x],
+                dtype=v_buffer.dtype,
+                device="meta",
+            )
+            new_key_cache = k_buffer.view_as(k_cache_template)
+            new_value_cache = v_buffer.view_as(v_cache_template)
+            q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            pa_fwd_asm(
+                Q=q,
+                K=new_key_cache,
+                V=new_value_cache,
+                block_tables=self.forward_metadata.page_table,
+                context_lens=self.forward_metadata.kv_lens,
+                block_tables_stride0=self.forward_metadata.page_table.stride(0),
+                K_QScale=self.k_scale,
+                V_QScale=self.v_scale,
+                out_=o,
+            )
+        return o
+
+    def forward_decode_pa_ps(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # Create o as 3D tensor [batch_size, num_heads, head_dim] for both MLA and pa_fwd_asm
+        # In decode mode, q.shape[0] equals batch_size (each sequence has 1 token)
+        # Use q.shape[0] instead of forward_batch.batch_size to be safe
+        batch_size = q.shape[0]
+        head_dim_out = (
+            layer.v_head_dim
+            if layer.qk_head_dim != layer.v_head_dim
+            else layer.head_dim
+        )
+        o = q.new_empty((batch_size, layer.tp_q_head_num, head_dim_out))
+
+        if save_kv_cache:
+            if self.use_mla:
+                raise NotImplementedError(
+                    "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl."
+                )
+            else:
+                k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
+                    layer.layer_id
+                )
+                self.set_kv_buffer_with_layout_shuffle(
+                    forward_batch.out_cache_loc,
+                    k,
+                    v,
+                    k_buffer,
+                    v_buffer,
+                    layer.k_scale,
+                    layer.v_scale,
+                    self.page_size,
+                )
+                # Shuffle operation is already fused in rotary_emb, so just save directly
+                # forward_batch.token_to_kv_pool.set_kv_buffer(
+                #     layer, forward_batch.out_cache_loc, k, v, layer.k_scale, layer.v_scale
+                # )
+
+        if self.use_mla:
+            raise NotImplementedError(
+                "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl."
+            )
+        else:
+            k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+            num_slots, num_kv_heads, head_size = k_buffer.shape
+            block_size = self.page_size
+            num_blocks = num_slots // block_size
+            k_buffer = k_buffer[: num_blocks * block_size].view(
+                num_blocks, block_size, num_kv_heads, head_size
+            )
+            v_buffer = v_buffer[: num_blocks * block_size].view(
+                num_blocks, block_size, num_kv_heads, head_size
+            )
+
+            quant_dtype = dtypes.fp8
+            x = 16 // quant_dtype.itemsize
+            k_cache_template = torch.empty(
+                [num_blocks, num_kv_heads, head_size // x, block_size, x],
+                dtype=k_buffer.dtype,
+                device="meta",
+            )
+            # V: [num_blocks, block_size, num_kv_heads, head_size] -> [num_blocks, num_kv_heads, block_size // x, head_size, x]
+            v_cache_template = torch.empty(
+                [num_blocks, num_kv_heads, block_size // x, head_size, x],
+                dtype=v_buffer.dtype,
+                device="meta",
+            )
+            new_key_cache = k_buffer.view_as(k_cache_template)
+            new_value_cache = v_buffer.view_as(v_cache_template)
+
+            total_tokens = num_blocks * block_size
+            k_qscale = self.k_qscale[:, :total_tokens]
+            v_qscale = self.v_qscale[:, :total_tokens]
+
+            q = q.view(batch_size, layer.tp_q_head_num, layer.head_dim)
+
+            assert (
+                self.forward_metadata.pa_metadata_qo_indptr is not None
+            ), "pa_metadata_qo_indptr should be set by _build_pa_metadata_for_decode"
+            assert (
+                self.forward_metadata.pa_metadata_pages_kv_indptr is not None
+            ), "pa_metadata_pages_kv_indptr should be set by _build_pa_metadata_for_decode"
+            assert (
+                self.forward_metadata.pa_metadata_kv_indices is not None
+            ), "pa_metadata_kv_indices should be set by _build_pa_metadata_for_decode"
+            assert (
+                self.forward_metadata.pa_metadata_context_lens is not None
+            ), "pa_metadata_context_lens should be set by _build_pa_metadata_for_decode"
+            assert (
+                self.forward_metadata.pa_metadata_max_qlen is not None
+            ), "pa_metadata_max_qlen should be set by _build_pa_metadata_for_decode"
+
+            qo_indptr = self.forward_metadata.pa_metadata_qo_indptr
+            kv_indptr = self.forward_metadata.pa_metadata_pages_kv_indptr
+            kv_indices = self.forward_metadata.pa_metadata_kv_indices
+            context_lens = self.forward_metadata.pa_metadata_context_lens
+            max_qlen = self.forward_metadata.pa_metadata_max_qlen
+
+            _, _ = pa_persistent_fwd(
+                Q=q,
+                K=new_key_cache,
+                V=new_value_cache,
+                output=o,
+                max_qlen=max_qlen,
+                qo_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
+                kv_indices=kv_indices,
+                context_lens=context_lens,
+                work_indptr=self.pa_metadata_buffers["work_indptr"],
+                work_info=self.pa_metadata_buffers["work_info"],
+                reduce_indptr=self.pa_metadata_buffers["reduce_indptr"],
+                reduce_final_map=self.pa_metadata_buffers["reduce_final_map"],
+                reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"],
+                K_QScale=k_qscale,
+                V_QScale=v_qscale,
+                softmax_scale=layer.scaling,
+                mask=1,
+            )
+        return o.view(-1, layer.tp_q_head_num * head_dim_out)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if self.use_mla:
+            raise NotImplementedError(
+                "MLA decode mode is not implemented yet in ATOMAttnBackendForSgl."
+            )
+        else:
+            if self.decode_using_pa_ps:
+                return self.forward_decode_pa_ps(
+                    q, k, v, layer, forward_batch, save_kv_cache
+                )
+            else:
+                return self.forward_decode_pa(
+                    q, k, v, layer, forward_batch, save_kv_cache
+                )
diff --git a/atom/plugin/config.py b/atom/plugin/config.py
index bf1af0f62..21fa7f131 100644
--- a/atom/plugin/config.py
+++ b/atom/plugin/config.py
@@ -128,10 +128,13 @@ def _generate_atom_config_from_sglang_config(config: Any):
     from sglang.srt.configs.load_config import LoadConfig
     from atom.config import Config, ParallelConfig, CompilationConfig
 
+    # Format 1: sglang serve --model-path ...
+    # Format 2: python3 -m sglang.launch_server --model-path ...
+    args_list = sys.argv[2:] if len(sys.argv) > 1 and sys.argv[1] == "serve" else sys.argv[1:]
     # sglang has no global config variable like vllm,
     # so here construct the server args from sys.argv passed by users
     # this is the only way to get full arguments
-    server_args: ServerArgs = prepare_server_args(sys.argv[1:])
+    server_args: ServerArgs = prepare_server_args(args_list)
 
     sgl_model_config = SglangModelConfig.from_server_args(server_args)
     sgl_model_opt_config = ModelOptConfig(
@@ -223,7 +226,6 @@ def generate_atom_config_for_plugin_mode(config: Any = None):
     """
     logger.info("Generate atom config for plugin mode from passed config")
 
-    atom_config = None
     from atom.plugin import is_vllm, is_sglang
     from atom.config import set_current_atom_config
 
diff --git a/atom/plugin/register.py b/atom/plugin/register.py
index d76e2e86c..69e6b3a1e 100644
--- a/atom/plugin/register.py
+++ b/atom/plugin/register.py
@@ -26,9 +26,9 @@ def _register_custom_attention_to_sglang() -> None:
 
     @register_attention_backend("aiter")
     def create_atom_backend(runner):
-        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+        from atom.plugin.attention_backend.sgl_attn_backend import ATOMAttnBackendForSgl
 
-        return AiterAttnBackend(runner)
+        return ATOMAttnBackendForSgl(runner)
 
 
 def register_ops_to_sglang(atom_config: Config) -> None:
diff --git a/atom/utils/envs.py b/atom/utils/envs.py
index 62ce11bb5..3c9100c54 100644
--- a/atom/utils/envs.py
+++ b/atom/utils/envs.py
@@ -42,6 +42,7 @@
         "ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1"
     ) == "1",
+    "ATOM_ROPE_FUSED_QKNORM": lambda: os.getenv("ATOM_ROPE_FUSED_QKNORM", "0") == "1",
}
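
Note (not part of the patch): the shuffled cache layout written by reshape_and_cache_shuffle_kernel can be hard to review from the flat offset arithmetic alone. Below is a minimal plain-PyTorch sketch of the same K/V permutation, assuming the unshuffled caches are [num_blocks, block_size, num_kv_heads, head_size]; the helper names are hypothetical and exist only for illustration. Here x is 16 bytes divided by the cache element size, matching `x = 16 // k_buffer.element_size()` in the diff, i.e. 8 for fp16/bf16 caches and 16 for fp8.

import torch

def shuffle_k_reference(k_cache: torch.Tensor, x: int) -> torch.Tensor:
    # Hypothetical reference, not the PR's kernel.
    # [num_blocks, block_size, num_kv_heads, head_size]
    #   -> [num_blocks, num_kv_heads, head_size // x, block_size, x]
    num_blocks, block_size, num_kv_heads, head_size = k_cache.shape
    return (
        k_cache.permute(0, 2, 3, 1)  # [blocks, kv_heads, head_size, block_size]
        .reshape(num_blocks, num_kv_heads, head_size // x, x, block_size)
        .transpose(3, 4)  # swap the x and block_size axes
        .contiguous()
    )

def shuffle_v_reference(v_cache: torch.Tensor, x: int) -> torch.Tensor:
    # Hypothetical reference, not the PR's kernel.
    # [num_blocks, block_size, num_kv_heads, head_size]
    #   -> [num_blocks, num_kv_heads, block_size // x, head_size, x]
    num_blocks, block_size, num_kv_heads, head_size = v_cache.shape
    return (
        v_cache.permute(0, 2, 1, 3)  # [blocks, kv_heads, block_size, head_size]
        .reshape(num_blocks, num_kv_heads, block_size // x, x, head_size)
        .transpose(3, 4)  # swap the x and head_size axes
        .contiguous()
    )

Element (d // x, t, d % x) of a K tile holds head-dim d of token t, which is exactly the dst_k_shuffle_offset computed in the Triton kernel; likewise (t // x, d, t % x) for V.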
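
Note (not part of the patch): a small worked example of the paging metadata that _build_pa_metadata_for_decode derives before calling get_pa_metadata_v1. The numbers are made up; page_size = 1024 follows the `decode_using_pa_ps = self.page_size == 1024` path in the backend.

import torch

# Two decode sequences with 1500 and 800 cached tokens, page_size = 1024.
page_size = 1024
context_lens = torch.tensor([1500, 800], dtype=torch.int32)

# ceil-div: blocks needed per sequence -> [2, 1]
num_blocks_per_seq = (context_lens + page_size - 1) // page_size

# exclusive prefix sum over block counts -> [0, 2, 3]
pages_kv_indptr = torch.zeros(len(context_lens) + 1, dtype=torch.int32)
pages_kv_indptr[1:] = torch.cumsum(num_blocks_per_seq, dim=0)

# page_table rows hold per-sequence block ids (already divided by page_size);
# create_flashinfer_kv_indices_triton flattens each row's valid prefix:
page_table = torch.tensor([[7, 3], [5, 0]], dtype=torch.int32)
kv_indices = torch.cat(
    [page_table[i, : num_blocks_per_seq[i]] for i in range(len(context_lens))]
)  # -> [7, 3, 5]

# qo_indptr in decode is just [0, 1, 2]: one query token per sequence,
# which is why the backend can pre-initialize pa_decode_qo_indptr once.
qo_indptr = torch.arange(len(context_lens) + 1, dtype=torch.int32)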