5 changes: 5 additions & 0 deletions lightllm/common/basemodel/attention/__init__.py
@@ -10,9 +10,14 @@
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend

# NSA backend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend

from .create_utils import (
get_prefill_att_backend_class,
get_decode_att_backend_class,
get_mla_prefill_att_backend_class,
get_mla_decode_att_backend_class,
get_nsa_prefill_att_backend_class,
get_nsa_decode_att_backend_class,
)
5 changes: 5 additions & 0 deletions lightllm/common/basemodel/attention/base_att.py
@@ -65,6 +65,11 @@ class AttControl:
mla_prefill_dict: Dict = None
mla_decode: bool = False
mla_decode_dict: Dict = None
# nsa (native sparse attention) specific parameters
nsa_prefill: bool = False
nsa_prefill_dict: Dict = None
nsa_decode: bool = False
nsa_decode_dict: Dict = None


@dataclass
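The hunk above only declares the new fields. As a quick illustration (not part of the PR), a caller on the prefill path might populate them as below, assuming the other AttControl fields all carry defaults; the dict keys mirror what the NSA prefill state later in this PR reads, and the concrete values are placeholders.

```python
# Sketch only: filling the new NSA prefill fields on AttControl.
import torch
from lightllm.common.basemodel.attention.base_att import AttControl

topk_indices = torch.zeros((8, 64), dtype=torch.int32)  # [tokens, topk] selected KV positions (placeholder)
att_control = AttControl(
    nsa_prefill=True,
    nsa_prefill_dict={
        "topk_indices": topk_indices,
        "softmax_scale": 0.1147,  # e.g. 1 / sqrt(head_dim); model dependent
        "kv_lora_rank": 512,      # value width handed to the sparse kernel
    },
)
```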
24 changes: 24 additions & 0 deletions lightllm/common/basemodel/attention/create_utils.py
@@ -16,6 +16,7 @@
from .flashinfer.fp8 import Fp8FlashInferAttBackend
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend

logger = init_logger(__name__)

@@ -46,6 +47,13 @@
},
}

nsa_data_type_to_backend = {
"None": {
"flashmla_sparse": NsaFlashMlaSparseAttBackend,
# Future backends: "fa3", "tilelang", "aiter"
},
}


def _auto_select_backend(
llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"]
@@ -105,3 +113,19 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla
return mla_data_type_to_backend[llm_dtype][backend_str]
else:
return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list)


def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend:
llm_dtype = "None"
if backend_str not in nsa_data_type_to_backend[llm_dtype]:
logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse")
backend_str = "flashmla_sparse"
return nsa_data_type_to_backend[llm_dtype][backend_str]


def get_nsa_decode_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend:
llm_dtype = "None"
if backend_str not in nsa_data_type_to_backend[llm_dtype]:
logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse")
backend_str = "flashmla_sparse"
return nsa_data_type_to_backend[llm_dtype][backend_str]
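A short usage sketch for the new NSA helpers, assuming they are consumed like the existing get_*_att_backend_class functions (they return a backend class that is then instantiated with the model):

```python
# Sketch only: resolving an NSA backend class and exercising the fallback path.
from lightllm.common.basemodel.attention import (
    get_nsa_prefill_att_backend_class,
    get_nsa_decode_att_backend_class,
)

prefill_cls = get_nsa_prefill_att_backend_class("flashmla_sparse")
decode_cls = get_nsa_decode_att_backend_class("tilelang")  # not registered yet -> warns and falls back

# Both currently resolve to NsaFlashMlaSparseAttBackend, whose __init__ takes the model:
# backend = prefill_cls(model)
```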
13 changes: 13 additions & 0 deletions lightllm/common/basemodel/attention/nsa/__init__.py
@@ -0,0 +1,13 @@
"""NSA (Native Sparse Attention) backend implementations."""

from .flashmla_sparse import (
NsaFlashMlaSparseAttBackend,
NsaFlashMlaSparsePrefillAttState,
NsaFlashMlaSparseDecodeAttState,
)

__all__ = [
"NsaFlashMlaSparseAttBackend",
"NsaFlashMlaSparsePrefillAttState",
"NsaFlashMlaSparseDecodeAttState",
]
134 changes: 134 additions & 0 deletions lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
@@ -0,0 +1,134 @@
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/nsa_backend.py
# Uses sgl_kernel.flash_mla and sgl_kernel.flash_attn from the sglang kernel library.

import dataclasses
import torch
from typing import Tuple, TYPE_CHECKING

from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from lightllm.utils.dist_utils import get_current_device_id

if TYPE_CHECKING:
from lightllm.common.basemodel.infer_struct import InferStateInfo


class NsaFlashMlaSparseAttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)

def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState":
return NsaFlashMlaSparsePrefillAttState(backend=self, infer_state=infer_state)

def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparseDecodeAttState":
return NsaFlashMlaSparseDecodeAttState(backend=self, infer_state=infer_state)


@dataclasses.dataclass
class NsaFlashMlaSparsePrefillAttState(BasePrefillAttState):
"""Prefill attention state for NSA using flash_mla_sparse_fwd."""

cu_seqlens_q: torch.Tensor = None
cu_seqlens_k: torch.Tensor = None

def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

def prefill_att(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"

return self._nsa_prefill_att(q=q, kv=k, att_control=att_control)

def _nsa_prefill_att(
self,
q: torch.Tensor,
kv: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
from sgl_kernel.flash_mla import flash_mla_sparse_fwd

nsa_dict = att_control.nsa_prefill_dict
topk_indices = nsa_dict["topk_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]

if topk_indices.ndim == 2:
topk_indices = topk_indices.unsqueeze(1)

mla_out, _, _ = flash_mla_sparse_fwd(
q=q,
kv=kv,
indices=topk_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
)
return mla_out


@dataclasses.dataclass
class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState):
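    """Decode attention state for NSA using flash_attn_with_kvcache."""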

cu_seqlens_q: torch.Tensor = None
cu_seqlens_k: torch.Tensor = None

def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

def decode_att(
self,
q: Tuple[torch.Tensor, torch.Tensor],
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"

return self._nsa_decode_att(q=q, kv=k, att_control=att_control)

def _nsa_decode_att(
self,
q: Tuple[torch.Tensor, torch.Tensor],
kv: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
from sgl_kernel.flash_attn import flash_attn_with_kvcache

nsa_dict = att_control.nsa_decode_dict
topk_indices = nsa_dict["topk_indices"]
nsa_cache_seqlens = nsa_dict["nsa_cache_seqlens"]
nsa_cu_seqlens_k = nsa_dict["nsa_cu_seqlens_k"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
qk_rope_head_dim = nsa_dict["qk_rope_head_dim"]

q_nope, q_rope = q

# Extract k_rope and kv_nope from the KV buffer
k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim)
kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank)

o_tensor = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope,
v_cache=kv_nope,
qv=q_nope,
page_table=topk_indices,
cache_seqlens=nsa_cache_seqlens,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=nsa_cu_seqlens_k,
max_seqlen_q=self.infer_state.max_q_seq_len,
softmax_scale=softmax_scale,
causal=True,
)
return o_tensor
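Tying the pieces together, a hedged sketch of the decode call path as defined above. The methods and dict keys come from this file; `model`, `infer_state`, `q_nope`/`q_rope`, `kv_buffer` and the index/length tensors are placeholders a real caller would already hold, and the numeric values are illustrative only.

```python
# Sketch only: NSA decode flow. Lowercase names not imported here are placeholders.
from lightllm.common.basemodel.attention import NsaFlashMlaSparseAttBackend
from lightllm.common.basemodel.attention.base_att import AttControl

att_control = AttControl(
    nsa_decode=True,
    nsa_decode_dict={
        "topk_indices": topk_indices,            # [batch, topk] page-table style indices
        "nsa_cache_seqlens": nsa_cache_seqlens,  # int32 per-request selected-KV lengths
        "nsa_cu_seqlens_k": nsa_cu_seqlens_k,    # cumulative form of the lengths above
        "softmax_scale": softmax_scale,
        "kv_lora_rank": 512,                     # width of the nope part of the KV buffer
        "qk_rope_head_dim": 64,                  # width of the rope tail of the KV buffer
    },
)

backend = NsaFlashMlaSparseAttBackend(model)
decode_state = backend.create_att_decode_state(infer_state)
decode_state.init_state()
# q arrives pre-split into (q_nope, q_rope); k is the compressed KV buffer; v is unused here.
out = decode_state.decode_att(q=(q_nope, q_rope), k=kv_buffer, v=None, att_control=att_control)
```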
Original file line number Diff line number Diff line change
@@ -98,7 +98,7 @@ def _init_parallel_params(self):
self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_
if self.enable_ep_moe:
assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe"
logger.info(
logger.debug(
f"global_rank {self.global_rank_} layerindex {self.layer_num_} "
f"redundancy_expertids: {self.redundancy_expert_ids}"
)
Original file line number Diff line number Diff line change
@@ -36,11 +36,20 @@ def load_hf_weights(self, weights):
"""
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2:
if isinstance(attr, TransformerLayerWeight):
attr.load_hf_weights(weights)
elif isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2:
with self.lock:
attr.load_hf_weights(weights)
elif isinstance(attr, BaseWeight):
attr.load_hf_weights(weights)

def verify_load(self):
for attr_name in dir(self):
attr = getattr(self, attr_name, None)
if isinstance(attr, TransformerLayerWeight):
attr.verify_load()
super().verify_load()

def get_quant_method(self, name):
return self.quant_cfg.get_quant_method(self.layer_num_, name)
26 changes: 11 additions & 15 deletions lightllm/common/req_manager.py
@@ -7,6 +7,7 @@
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.config_utils import get_vocab_size
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

logger = init_logger(__name__)

@@ -155,7 +156,11 @@ def init_req_sampling_params(self, req):
else:
self.req_to_out_token_id_counter[req.req_idx].fill_(0)
if req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics:
prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True)
prompt_ids = g_pin_mem_manager.gen_from_list(
key="prompt_ids_for_penalty",
data=req.shm_req.get_prompt_ids_numpy(),
dtype=torch.int32,
).cuda(non_blocking=True)
token_id_counter(
prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx]
)
@@ -214,22 +219,13 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
cum_sum_len += len(id_to_count)
p_cumsum_seq_len.append(cum_sum_len)

from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

p_token_ids_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_token_ids", size=len(p_token_ids), dtype=torch.int32
)
p_token_ids_tensor.numpy()[:] = p_token_ids

p_token_counts_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_token_counts", size=len(p_token_counts), dtype=torch.int32
p_token_ids_tensor = g_pin_mem_manager.gen_from_list(key="p_token_ids", data=p_token_ids, dtype=torch.int32)
p_token_counts_tensor = g_pin_mem_manager.gen_from_list(
key="p_token_counts", data=p_token_counts, dtype=torch.int32
)
p_token_counts_tensor.numpy()[:] = p_token_counts

p_cumsum_seq_len_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_cumsum_seq_len", size=len(p_cumsum_seq_len), dtype=torch.int32
p_cumsum_seq_len_tensor = g_pin_mem_manager.gen_from_list(
key="p_cumsum_seq_len", data=p_cumsum_seq_len, dtype=torch.int32
)
p_cumsum_seq_len_tensor.numpy()[:] = p_cumsum_seq_len

return (
p_token_ids_tensor.cuda(non_blocking=True),
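The req_manager changes replace the ad-hoc pin_memory()/alloc_pin_tensor-plus-copy pattern with g_pin_mem_manager.gen_from_list. A minimal sketch of that pattern, reusing a key and dtype from the diff with a made-up token list:

```python
# Sketch only: build a pinned host tensor from a Python list under a reusable key,
# then copy it to the GPU asynchronously, as the diff now does for sampling params.
import torch
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

p_token_ids = [3, 7, 7, 42]  # made-up token ids
p_token_ids_tensor = g_pin_mem_manager.gen_from_list(
    key="p_token_ids", data=p_token_ids, dtype=torch.int32
).cuda(non_blocking=True)
```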
1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -18,6 +18,7 @@
from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel
from lightllm.models.phi3.model import Phi3TpPartModel
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel
from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel
from lightllm.models.internvl.model import (
InternVLLlamaTpPartModel,
Empty file.