From f6191a05ab8fd9d66d98fda2c7c1e14a85cb4003 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 5 Nov 2025 03:55:14 +0000 Subject: [PATCH 01/23] support deepseek v3.2 --- .../layer_weights/transformer_layer_weight.py | 11 +- lightllm/models/deepseek3_2/infer_struct.py | 9 ++ .../layer_infer/nsa_indexer_layer_inder.py | 142 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 127 ++++++++++++++++ .../layer_weights/nsa_indexer_layer_weight.py | 49 ++++++ .../layer_weights/transformer_layer_weight.py | 16 ++ lightllm/models/deepseek3_2/mem_manager.py | 47 ++++++ lightllm/models/deepseek3_2/model.py | 38 +++++ .../deepseek3_2/triton_kernel/__init__.py | 0 .../deepseek3_2/triton_kernel/act_quant.py | 137 +++++++++++++++++ .../triton_kernel/token_group_quant.py | 103 +++++++++++++ 11 files changed, 678 insertions(+), 1 deletion(-) create mode 100644 lightllm/models/deepseek3_2/infer_struct.py create mode 100644 lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py create mode 100644 lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py create mode 100644 lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/deepseek3_2/mem_manager.py create mode 100644 lightllm/models/deepseek3_2/model.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/__init__.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/act_quant.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 86a887a25..0f4d6b13a 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -36,11 +36,20 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: + if isinstance(attr, TransformerLayerWeight): + attr.load_hf_weights(weights) + elif isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + def verify_load(self): + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, TransformerLayerWeight): + attr.verify_load() + super().verify_load() + def get_quant_method(self, name): return self.quant_cfg.get_quant_method(self.layer_num_, name) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py new file mode 100644 index 000000000..6e5e766b2 --- /dev/null +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -0,0 +1,9 @@ +import os +import torch +import numpy as np +import torch.distributed as dist +from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo + + +class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): + pass \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py new file mode 100644 index 000000000..a3891f0f3 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -0,0 +1,142 @@ +from sgl_kernel import 
fast_topk_transform_fused +import deep_gemm +import torch +import torch.nn.functional as F + +from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant + + +class NSAIndexerInfer(BaseLayerInfer): + def __init__(self, layer_idx, network_config, mode=[]): + super().__init__() + self.layer_idx_ = layer_idx + self.network_config_ = network_config + self.mode = mode + self.index_topk = network_config["index_topk"] + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = 1 + self.tp_v_head_num_ = 1 + self.qk_nope_head_dim = network_config["qk_nope_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_head_dim = network_config["index_head_dim"] + self.eps = network_config["rms_norm_eps"] + self.block_size = network_config["quantization_config"]["weight_block_size"][0] + self.scale_fmt = network_config["quantization_config"]["scale_fmt"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.index_n_heads = network_config["index_n_heads"] + self.index_n_heads_scale = self.index_n_heads ** -0.5 + + self.q_lora = None + self.hidden_states = None + return + + def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + + def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + assert self.hidden_states is not None + assert self.q_lora is not None + + q, k = self._get_q_k_bf16(infer_state, layer_weight) + q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + + weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + + logits = fp8_paged_mqa_logits_torch( + q_fp8, k_fp8, weights, + infer_state.lengths, + infer_state.page_table, + infer_state.max_model_len + ) + + return fast_topk_transform_fused( + score=logits, + lengths=infer_state.lengths, + page_table_size_1=infer_state.page_table, + cu_seqlens_q=infer_state.b1_cu_q_seq_len, + topk=self.index_topk + ) + + @staticmethod + def _rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from sgl_kernel import hadamard_transform + + hidden_size = x.size(-1) + assert ( + hidden_size & (hidden_size - 1) + ) == 0, 
"Hidden size must be a power of 2 for Hadamard transform." + return hadamard_transform(x, scale=hidden_size**-0.5) + + def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight): + q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) + self.q_lora = None + + k = layer_weight.wk_proj_.mm(self.hidden_states) + self.hidden_states = None + k = F.layer_norm( + k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps + ).type_as(k) + + rotary_emb_fwd( + q[:, :, : self.qk_rope_head_dim], + k[:, None, : self.qk_rope_head_dim], + infer_state.position_cos, + infer_state.position_sin, + ) + + q = self._rotate_activation(q) + k = self._rotate_activation(k) + return q, k + + +# TODO +def fp8_paged_mqa_logits_torch(q: torch.Tensor, kv_cache: torch.Tensor, + weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, + max_model_len: int): + batch_size, next_n, heads, dim = q.size() + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') + weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() + for block_rk in range((context_len + block_size - 1) // block_size): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) + s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf')) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + return logits \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..6db8c14e8 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -0,0 +1,127 @@ +from functools import partial +from typing import override + +import torch +from sgl_kernel.flash_mla import flash_mla_sparse_fwd +from sgl_kernel.flash_attn import flash_attn_with_kvcache + +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer +from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd + + +class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + self.index_topk = network_config["index_topk"] + 
super().__init__(layer_num, network_config, mode) + + self.indexer = NSAIndexerInfer( + layer_idx=self.layer_num_, + network_config=self.network_config_, + mode=mode + ) + return + + @override + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionInferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + + if self.q_lora_rank is None: + q = layer_weight.q_weight_.mm(input) + cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + else: + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + + self.indexer.hidden_states = input + self.indexer.q_lora = q + + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + rmsnorm_forward( + cache_kv[:, :, : self.kv_lora_rank], + weight=layer_weight.kv_a_layernorm_.weight, + eps=self.eps_, + out=cache_kv[:, :, : self.kv_lora_rank], + ) + + rotary_emb_fwd( + q_rope, + cache_kv[:, :, self.kv_lora_rank :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + @override + def _bind_attention(self): + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) + self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) + pass + + def _context_attention_flashmla_kernel_with_indexer( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek3_2FlashInferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + q_all = torch.cat([q_nope, q_rope], dim=-1) + topk_indices = self.indexer.get_indices( + infer_state, + layer_weight.indexer_layer_weight, + ) + mla_out, _, _ = flash_mla_sparse_fwd( + q=q_all, + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], + indices=topk_indices.unsqueeze(1), + sm_scale=self.softmax_scale, + d_v=self.kv_lora_rank, + ) + return mla_out + + def _token_attention_flashmla_kernel_with_indexer( + self, q, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + topk_indices = self.indexer.get_indices( + infer_state, + layer_weight.indexer_layer_weight, + ) + o = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=topk_indices, + cache_seqlens=infer_state.b_att_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=infer_state.max_q_seq_len, + 
softmax_scale=self.softmax_scale, + causal=True, + softcap=0.0, + return_softmax_lse=False, + num_splits=0, # TODO enable_deterministic_inference + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py new file mode 100644 index 000000000..47e0bfdac --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -0,0 +1,49 @@ +from typing_extensions import override + +import torch + +from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, NormWeight + + +class NSAIndexerWeight(TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + @override + def _init_weight(self): + prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" + + self.wq_b_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wq_b.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wq_b", + tp_rank=0, + tp_world_size=1, + ) + self.wk_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wk.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wk", + tp_rank=0, + tp_world_size=1, + ) + self.k_norm_ = NormWeight( + f"{prefix}.k_norm.weight", + torch.float32, + bias_name=f"{prefix}.k_norm.bias" + ) + self.weights_proj_ = ROWMMWeight( + weight_name=f"{prefix}.weights_proj.weight", + data_type=self.data_type_, + quant_cfg=None, + layer_num=self.layer_num_, + name="weights_proj", + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..2a03e1d6a --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -0,0 +1,16 @@ +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight + + +class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + self.indexer_layer_weight = NSAIndexerWeight( + layer_num=layer_num, + data_type=data_type, + network_config=network_config, + mode=mode, + quant_cfg=quant_cfg + ) + return diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py new file mode 100644 index 000000000..0aa0a0bdb --- /dev/null +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -0,0 +1,47 @@ +from typing_extensions import override +import torch + +from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager + + +class Deepseek3_2MemoryManager(Deepseek2MemoryManager): + def __init__( + self, + size, + dtype, + head_num, + head_dim, + layer_num, + index_head_dim, + index_quant_block_size, + k_cache_dtype=torch.float8_e4m3fn, + k_scale_dtype=torch.float32, + always_copy=False, + mem_fraction=0.9 + ): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + assert index_head_dim % index_quant_block_size == 0, 
"index_head_dim must be divisible by index_quant_block_size" + self.index_head_dim = index_head_dim + self.index_quant_block_size = index_quant_block_size + self.k_cache_dtype = k_cache_dtype + self.k_scale_dtype = k_scale_dtype + return + + @override + def get_cell_size(self): + index_k_cache_cell_size = self.index_head_dim * self.layer_num * torch._utils._element_size(self.k_cache_dtype) + index_k_scale_cell_size = (self.index_head_dim // self.index_quant_block_size) * self.layer_num * torch._utils._element_size(self.k_scale_dtype) + return super().get_cell_size() + index_k_cache_cell_size + index_k_scale_cell_size + + @override + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + super()._init_buffers(size, dtype, head_num, head_dim, layer_num) + self._init_indexer_k_cache_buffers() + return + + def _init_indexer_k_cache_buffers(self): + self.indexer_k_cache_buffers = torch.empty( + (self.layer_num, self.size + 1, self.index_head_dim), dtype=self.k_cache_dtype, device="cuda") + self.indexer_k_scale_buffers = torch.empty( + (self.layer_num, self.size + 1, self.index_head_dim // self.index_quant_block_size), dtype=self.k_scale_dtype, device="cuda") + return diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py new file mode 100644 index 000000000..3a244c77f --- /dev/null +++ b/lightllm/models/deepseek3_2/model.py @@ -0,0 +1,38 @@ +from lightllm.models.registry import ModelRegistry +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashInferStateInfo + +@ModelRegistry(["deepseek_v32"]) +class Deepseek3_2TpPartModel(Deepseek2TpPartModel): + # weight class + transformer_weight_class = Deepseek3_2TransformerLayerWeight + + # infer class + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + + # infer state class + infer_state_class = Deepseek3_2FlashInferStateInfo + + def _init_mem_manager(self): + # mtp 模式下需要在mem manger上扩展draft model使用的layer + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "deepseekv3_eagle": + added_mtp_layer_num += 1 + elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": + added_mtp_layer_num += get_env_start_args().mtp_step + + self.mem_manager = Deepseek3_2MemoryManager( + self.max_total_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + index_head_dim = self.config["index_head_dim"], + index_quant_block_size = self.config["index_quant_block_size"], + mem_fraction=self.mem_fraction, + ) + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/__init__.py b/lightllm/models/deepseek3_2/triton_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek3_2/triton_kernel/act_quant.py b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py new file mode 100644 index 000000000..a4ecd0f51 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py @@ -0,0 +1,137 @@ +# Adapted from 
https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/python/sglang/srt/layers/attention/nsa/triton_kernel.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +# Triton implementation +@triton.jit +def _act_quant_kernel( + X_ptr, + Y_ptr, + S_ptr, + M, + N, + group_size: tl.constexpr, + round_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Triton kernel for activation quantization. + + Each block processes BLOCK_M rows and group_size columns. + """ + # Get block IDs + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # FP8 constants + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1.0 / fp8_max + + # Calculate row and column offsets + row_start = pid_m * BLOCK_M + col_start = pid_n * group_size + + # Create offset arrays + rows = row_start + tl.arange(0, BLOCK_M) + cols = col_start + tl.arange(0, BLOCK_N) + + # Mask for valid rows and columns + row_mask = rows < M + col_mask = cols < N + mask = row_mask[:, None] & col_mask[None, :] + + # Load input data + x_ptrs = X_ptr + rows[:, None] * N + cols[None, :] + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Compute absolute max along columns (group_size dimension) for each row + x_abs = tl.abs(x) + amax = tl.max(x_abs, axis=1) # Shape: (BLOCK_M,) + + # Clamp amax to avoid division by zero + amax = tl.maximum(amax, 1e-4) + + # Compute scale + if round_scale: + # Fast round scale using bit manipulation approximation + # This is a simplified version - the exact bit manipulation is harder in Triton + # Using log2 + ceil + pow2 as approximation + log_val = tl.log2(amax * fp8_max_inv) + log_ceil = tl.ceil(log_val) + scale = tl.exp2(log_ceil) + else: + scale = amax * fp8_max_inv + + # Quantize: y = clamp(x / scale, fp8_min, fp8_max) + scale_broadcast = scale[:, None] + y = x / scale_broadcast + y = tl.minimum(tl.maximum(y, fp8_min), fp8_max) + + # Store quantized output + y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :] + tl.store(y_ptrs, y, mask=mask) + + # Store scales + s_cols = pid_n + s_ptrs = S_ptr + rows * (N // group_size) + s_cols + s_mask = row_mask + tl.store(s_ptrs, scale, mask=s_mask) + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization with Triton. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. 
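+
+    Example (illustrative sketch, not part of the kernel contract; assumes a CUDA
+    device and a contiguous bfloat16 input whose last dim is a multiple of `block_size`):
+        >>> x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
+        >>> y, s = act_quant(x, block_size=128)
+        >>> y.dtype, s.dtype
+        (torch.float8_e4m3fn, torch.float32)
+        >>> y.shape, s.shape
+        (torch.Size([4, 256]), torch.Size([4, 2]))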
+ """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + + # Flatten all dims except last + N = x.size(-1) + x_flat = x.view(-1, N) + M = x_flat.size(0) + + # Allocate output tensors + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + y_flat = y.view(-1, N) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + s_flat = s.view(-1, N // block_size) + + # Launch kernel + BLOCK_M = 32 + BLOCK_N = block_size + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size)) + round_scale = scale_fmt is not None + + _act_quant_kernel[grid]( + x_flat, + y_flat, + s_flat, + M, + N, + group_size=block_size, + round_scale=round_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=0 if round_scale else 2, + ) + + return y, s diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py new file mode 100644 index 000000000..dbf5c5199 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -0,0 +1,103 @@ +import triton +import triton.language as tl +import torch +from typing import Tuple + +fp8_min = -448.0 +fp8_max = 448.0 +fp8_dtype = torch.float8_e4m3fn + +@triton.jit +def _per_token_group_quant_mla_deep_gemm_masked_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + masked_m_ptr, + group_size, + y_stride_b, + y_stride_t, + y_q_stride_b, + y_q_stride_t, + y_s_stride_b, + y_s_stride_g, + eps, + fp8_min, + fp8_max, + NUM_GROUP: tl.constexpr, + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor for deep_gemm grouped_gemm_masked. + This function converts the tensor values into float8 values. + y and y_q: (b, t, k) + y_s: (b, k//group_size, t) + """ + t_id = tl.program_id(0) + b_id = tl.program_id(1) + + y_ptr += b_id * y_stride_b + t_id * y_stride_t + y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t + y_s_ptr += b_id * y_s_stride_b + t_id + + if t_id == 0: + tl.store(masked_m_ptr + b_id, tl.num_programs(0)) + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + for gid in range(NUM_GROUP): + y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to( + tl.float32 + ) + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask) + tl.store(y_s_ptr + gid * y_s_stride_g, y_s) + + +def per_token_group_quant_mla_deep_gemm_masked_fp8( + x: torch.Tensor, + group_size: int = 128, + eps: float = 1e-12, + dtype: torch.dtype = fp8_dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function quantizes input values to float8 values with per-token-group-quantization + for deep_gemm grouped_gemm_masked and specialized for mla absorbed case. 
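+
+    Shape example (illustrative sketch; CUDA assumed, values follow the allocation
+    logic below, where `aligned_m` is `m` rounded up to a multiple of 256 and the
+    function actually returns a 5-tuple):
+        >>> x = torch.randn(2, 16, 256, dtype=torch.bfloat16, device="cuda")
+        >>> x_q, x_s, masked_m, m, aligned_m = per_token_group_quant_mla_deep_gemm_masked_fp8(x)
+        >>> x_q.shape, x_s.shape
+        (torch.Size([2, 256, 256]), torch.Size([2, 256, 2]))
+        >>> masked_m.tolist(), m, aligned_m
+        ([16, 16], 16, 256)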
+ """ + assert x.dim() == 3, "`x` is not a 3d-tensor" + + b, m, k = x.shape + aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel + num_tiles_k = k // group_size + assert num_tiles_k * group_size == k, f"k % {group_size} must be zero" + + x_q = x.new_empty((b, aligned_m, k), dtype=dtype) + x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32) + masked_m = x.new_empty((b,), dtype=torch.int32) + + BLOCK_SIZE = triton.next_power_of_2(group_size) + grid = (m, b) + + _per_token_group_quant_mla_deep_gemm_masked_fp8[grid]( + x, + x_q, + x_s, + masked_m, + group_size, + x.stride(0), + x.stride(1), + x_q.stride(0), + x_q.stride(1), + x_s.stride(0), + x_s.stride(1), + eps, + -fp8_max, + fp8_max, + num_tiles_k, + BLOCK_SIZE, + ) + + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m \ No newline at end of file From 9a5c4108317c2e8ff96b3c95bca699ef3abd3763 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 5 Nov 2025 06:22:57 +0000 Subject: [PATCH 02/23] fix --- lightllm/models/deepseek3_2/infer_struct.py | 7 ++- .../layer_infer/nsa_indexer_layer_inder.py | 9 ++- .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/models/deepseek3_2/mem_manager.py | 63 +++++++++++++------ lightllm/models/deepseek3_2/model.py | 2 - 5 files changed, 57 insertions(+), 26 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 6e5e766b2..20f8b7e8d 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -3,7 +3,12 @@ import numpy as np import torch.distributed as dist from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2IndexerPagedMemoryManager, Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): - pass \ No newline at end of file + + def __init__(self): + super().__init__() + assert isinstance(self.req_manager.mem_manager, Deepseek3_2MemoryManager) + self.indexer_paged_mem_manager : Deepseek3_2IndexerPagedMemoryManager = self.req_manager.mem_manager.indexer_paged_mem_manager diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index a3891f0f3..100df16f9 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -28,7 +28,7 @@ def __init__(self, layer_idx, network_config, mode=[]): self.scale_fmt = network_config["quantization_config"]["scale_fmt"] self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) self.index_n_heads = network_config["index_n_heads"] - self.index_n_heads_scale = self.index_n_heads ** -0.5 + self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale self.q_lora = None self.hidden_states = None @@ -67,8 +67,13 @@ def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, laye q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + # write + # infer_state.mem_manager. 
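+        # (sketch of the intended write path) the quantized indexer keys k_fp8 and their
+        # per-token scales k_scale are meant to be copied into a per-layer indexer cache
+        # slot selected by infer_state.mem_index; the follow-up commit below implements
+        # this with destindex_copy_indexer_ks writing into
+        # mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_].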
+ + # read + weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale - weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + weights = weights.unsqueeze(-1) * q_scale logits = fp8_paged_mqa_logits_torch( q_fp8, k_fp8, weights, diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 6db8c14e8..9f503e9bd 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -76,7 +76,7 @@ def _context_attention_flashmla_kernel_with_indexer( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashInferStateInfo, + infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index 0aa0a0bdb..f2613aacc 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,9 +1,37 @@ from typing_extensions import override import torch +from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager +from lightllm.utils.log_utils import init_logger +logger = init_logger(__name__) +class Deepseek3_2IndexerPagedMemoryManager: + def __init__(self, page_size): + self.page_size = page_size + return + + def set_size(self, size): + self.physics_size = size + self.num_pages = size // self.page_size + return + + def _init_buffers(self): + self.k_cache_buffer = torch.empty( + (self.page_size, 128), dtype=torch.float8_e4m3fn, device="cuda") + self.k_scale_buffer = torch.empty( + (self.page_size, 1), dtype=torch.float64, device="cuda") + return + + def alloc_paged_index(self, last_index: int, need_size): + pass + + def get_cell_size(self): + # Use for deepseek v3.2 exp only, 128 for k_cache(128 torch.float8_e4m3fn), 4 for scale(1 torch.float64) + return 128 + 4 + + class Deepseek3_2MemoryManager(Deepseek2MemoryManager): def __init__( self, @@ -12,36 +40,31 @@ def __init__( head_num, head_dim, layer_num, - index_head_dim, - index_quant_block_size, - k_cache_dtype=torch.float8_e4m3fn, - k_scale_dtype=torch.float32, always_copy=False, - mem_fraction=0.9 + mem_fraction=0.9, + page_size=64 ): + self.page_size = page_size + self.indexer_paged_mem_manager = Deepseek3_2IndexerPagedMemoryManager(page_size) super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - assert index_head_dim % index_quant_block_size == 0, "index_head_dim must be divisible by index_quant_block_size" - self.index_head_dim = index_head_dim - self.index_quant_block_size = index_quant_block_size - self.k_cache_dtype = k_cache_dtype - self.k_scale_dtype = k_scale_dtype + self.indexer_paged_mem_manager.set_size(self.size) return @override def get_cell_size(self): - index_k_cache_cell_size = self.index_head_dim * self.layer_num * torch._utils._element_size(self.k_cache_dtype) - index_k_scale_cell_size = (self.index_head_dim // self.index_quant_block_size) * self.layer_num * torch._utils._element_size(self.k_scale_dtype) - return super().get_cell_size() + index_k_cache_cell_size + index_k_scale_cell_size + return super().get_cell_size() + self.indexer_paged_mem_manager.get_cell_size() @override def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): super()._init_buffers(size, dtype, head_num, 
head_dim, layer_num) - self._init_indexer_k_cache_buffers() + self.indexer_paged_mem_manager._init_buffers() return - def _init_indexer_k_cache_buffers(self): - self.indexer_k_cache_buffers = torch.empty( - (self.layer_num, self.size + 1, self.index_head_dim), dtype=self.k_cache_dtype, device="cuda") - self.indexer_k_scale_buffers = torch.empty( - (self.layer_num, self.size + 1, self.index_head_dim // self.index_quant_block_size), dtype=self.k_scale_dtype, device="cuda") - return + @override + def profile_size(self, mem_fraction): + super().profile_size(mem_fraction) + if self.size % self.page_size != 0: + size_paged = (self.size // self.page_size + 1) * self.page_size + logger.warning(f"size {self.size} is not divisible by page_size {self.page_size}, will use paged_size {size_paged}") + self.size = size_paged + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 3a244c77f..5b3fc1f13 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -31,8 +31,6 @@ def _init_mem_manager(self): head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, - index_head_dim = self.config["index_head_dim"], - index_quant_block_size = self.config["index_quant_block_size"], mem_fraction=self.mem_fraction, ) return \ No newline at end of file From 3f49014e861fe1bacaeaa3c35cd29448ed4bb68e Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 6 Nov 2025 10:40:46 +0000 Subject: [PATCH 03/23] fix --- .../deepseek2_mem_manager.py | 4 +- .../kv_cache_mem_manager/mem_manager.py | 17 ++- lightllm/models/__init__.py | 1 + lightllm/models/deepseek3_2/infer_struct.py | 26 +++- .../layer_infer/nsa_indexer_layer_inder.py | 136 ++++++++++------- .../layer_infer/transformer_layer_infer.py | 15 +- lightllm/models/deepseek3_2/mem_manager.py | 72 ++------- lightllm/models/deepseek3_2/model.py | 13 +- .../destindex_copy_indexer_ks.py | 137 ++++++++++++++++++ .../triton_kernel/fp8_mqa_logits.py | 0 10 files changed, 283 insertions(+), 138 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 3d93e1b07..ad54b3935 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -15,8 +15,8 @@ class Deepseek2MemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec..2940d74e2 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -26,7 +26,7 @@ class MemoryManager: - def __init__(self, 
size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -48,15 +48,16 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.can_use_mem_size = self.size - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name + if not is_sub_mem_manager: + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 095f73679..fdd277f36 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -18,6 +18,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 20f8b7e8d..bfdb53fd6 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,14 +1,24 @@ -import os import torch -import numpy as np -import torch.distributed as dist from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2IndexerPagedMemoryManager, Deepseek3_2MemoryManager +class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): -class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): - def __init__(self): super().__init__() - assert isinstance(self.req_manager.mem_manager, Deepseek3_2MemoryManager) - self.indexer_paged_mem_manager : Deepseek3_2IndexerPagedMemoryManager = self.req_manager.mem_manager.indexer_paged_mem_manager + self.lengths = None + self.page_table_size_1 = None + self.ks = None + self.ke = None + return + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + super().init_some_extra_state(model, input_ids) + # Ensure b_ready_cache_len is set for both prefill and decode modes + if self.is_prefill: + # b_ready_cache_len is already set in basemodel.py for prefill + pass + else: + # In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len + # since b_q_seq_len represents the new tokens being processed + if self.b_ready_cache_len is None: + self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 100df16f9..1977c211e 100644 --- 
a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -5,10 +5,12 @@ from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant - +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager +from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +# from lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): @@ -30,8 +32,6 @@ def __init__(self, layer_idx, network_config, mode=[]): self.index_n_heads = network_config["index_n_heads"] self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale - self.q_lora = None - self.hidden_states = None return def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, @@ -59,7 +59,7 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T cost = mask.sum() return logits, cost - def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + def get_indices(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: assert self.hidden_states is not None assert self.q_lora is not None @@ -67,29 +67,78 @@ def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, laye q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) - # write - # infer_state.mem_manager. 
- - # read + self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager) weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale - weights = weights.unsqueeze(-1) * q_scale - - logits = fp8_paged_mqa_logits_torch( - q_fp8, k_fp8, weights, - infer_state.lengths, - infer_state.page_table, - infer_state.max_model_len + weights = weights.unsqueeze(-1) * q_scale + + ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + + k_fp8_list = [] + k_scale_list = [] + ks_list = [] + ke_list = [] + offset = 0 + for i in range(infer_state.batch_size): + q_len = infer_state.b_q_seq_len[i] + cache_len = infer_state.b_ready_cache_len[i] + mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] + k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous() + k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous() + ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda") + ke = ks + torch.arange(q_len, dtype=torch.int32, device="cuda") + 1 + k_fp8_list.append(k_fp8) + k_scale_list.append(k_scale) + ks_list.append(ks) + ke_list.append(ke) + offset += q_len + + k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) + k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) + kv_fp8 = (k_fp8, k_scale) + ks = torch.cat(ks_list, dim=0) + ke = torch.cat(ke_list, dim=0) + + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + kv_fp8, + weights.squeeze(-1), + ks, + ke, + clean_logits=False, ) - return fast_topk_transform_fused( - score=logits, - lengths=infer_state.lengths, - page_table_size_1=infer_state.page_table, - cu_seqlens_q=infer_state.b1_cu_q_seq_len, - topk=self.index_topk - ) - + return self.get_topk(logits, infer_state) + + def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo): + topk_indices_list = [] + offset = 0 + + for i in range(infer_state.batch_size): + q_len = infer_state.b_q_seq_len[i] + cache_len = infer_state.b_ready_cache_len[i] + end_pos = q_len + cache_len + # Slice logits for this batch (both query and sequence dimensions) + batch_logits = logits[offset:offset + q_len, :end_pos] + topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1] + mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] + indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda") + for j in range(q_len): + indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]] + topk_indices_list.append(indices) + offset += q_len + + topk_indices_ = torch.cat(topk_indices_list, dim=0) + + return topk_indices_ + + + def get_k_float32_from_buffer(self, buffer: torch.Tensor): + k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1] + k_float32 = k_fp8.float() * k_scale + return k_float32 + @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 @@ -101,12 +150,11 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: ) == 0, "Hidden size must be a power of 2 for Hadamard transform." 
return hadamard_transform(x, scale=hidden_size**-0.5) - def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight): + def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) self.q_lora = None k = layer_weight.wk_proj_.mm(self.hidden_states) - self.hidden_states = None k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) @@ -122,26 +170,16 @@ def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, la k = self._rotate_activation(k) return q, k - -# TODO -def fp8_paged_mqa_logits_torch(q: torch.Tensor, kv_cache: torch.Tensor, - weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, - max_model_len: int): - batch_size, next_n, heads, dim = q.size() - num_block, block_size, _, dim = kv_cache.size() - logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) - context_lens = context_lens.tolist() - for i in range(batch_size): - context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') - weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() - for block_rk in range((context_len + block_size - 1) // block_size): - block_idx = block_tables[i][block_rk] - qx, kx = q[i], kv_cache[block_idx] - k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') - mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) - s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf')) - s = torch.relu(s) * weight_slice[..., None] - s = s.sum(dim=0) - logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) - return logits \ No newline at end of file + def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager): + # k_fp8 : [seq_len, 128] torch.fp8_e4m3 + # k_scale : [seq_len, 1] torch.float32 + # mem_index : [seq_len] torch.int32 + # buffer : [10000000, 1, 132] torch.uint8 + buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + destindex_copy_indexer_ks( + k_fp8.unsqueeze(1), # Add head dimension: [seq_len, 1, 128] + k_scale.unsqueeze(1), # Add head dimension: [seq_len, 1, 1] + mem_index, + buffer + ) + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 9f503e9bd..076d3965c 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -8,7 +8,7 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from 
lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd @@ -30,7 +30,7 @@ def __init__(self, layer_num, network_config, mode=[]): def _get_qkv( self, input: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionInferStateInfo, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) @@ -68,6 +68,7 @@ def _get_qkv( @override def _bind_attention(self): + super()._bind_attention() self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) pass @@ -76,7 +77,7 @@ def _context_attention_flashmla_kernel_with_indexer( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashAttentionInferStateInfo, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: @@ -87,18 +88,19 @@ def _context_attention_flashmla_kernel_with_indexer( topk_indices = self.indexer.get_indices( infer_state, layer_weight.indexer_layer_weight, - ) + ).unsqueeze(1) + mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=topk_indices.unsqueeze(1), + indices=topk_indices, sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) return mla_out def _token_attention_flashmla_kernel_with_indexer( - self, q, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) @@ -125,3 +127,4 @@ def _token_attention_flashmla_kernel_with_indexer( return_softmax_lse=False, num_splits=0, # TODO enable_deterministic_inference ) + return o \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index f2613aacc..a70c76273 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,70 +1,22 @@ +from typing import List from typing_extensions import override import torch -from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.mem_manager import MemoryManager from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager -from lightllm.utils.log_utils import init_logger +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.distributed.pynccl import PyNcclCommunicator -logger = init_logger(__name__) - -class Deepseek3_2IndexerPagedMemoryManager: - def __init__(self, page_size): - self.page_size = page_size - return - - def set_size(self, size): - self.physics_size = size - self.num_pages = size // self.page_size - return - - def _init_buffers(self): - self.k_cache_buffer = torch.empty( - (self.page_size, 128), dtype=torch.float8_e4m3fn, device="cuda") - self.k_scale_buffer = torch.empty( - (self.page_size, 1), dtype=torch.float64, device="cuda") - return - - def alloc_paged_index(self, last_index: int, need_size): - pass - - def 
get_cell_size(self): - # Use for deepseek v3.2 exp only, 128 for k_cache(128 torch.float8_e4m3fn), 4 for scale(1 torch.float64) - return 128 + 4 - - class Deepseek3_2MemoryManager(Deepseek2MemoryManager): - def __init__( - self, - size, - dtype, - head_num, - head_dim, - layer_num, - always_copy=False, - mem_fraction=0.9, - page_size=64 - ): - self.page_size = page_size - self.indexer_paged_mem_manager = Deepseek3_2IndexerPagedMemoryManager(page_size) - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - self.indexer_paged_mem_manager.set_size(self.size) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9 ,is_sub_mem_manager=False): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) + self.indexer_ks_mem_manager = Deepseek2MemoryManager(self.size, torch.uint8, 1, 132, layer_num, is_sub_mem_manager=True) return @override def get_cell_size(self): - return super().get_cell_size() + self.indexer_paged_mem_manager.get_cell_size() - - @override - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - super()._init_buffers(size, dtype, head_num, head_dim, layer_num) - self.indexer_paged_mem_manager._init_buffers() - return - - @override - def profile_size(self, mem_fraction): - super().profile_size(mem_fraction) - if self.size % self.page_size != 0: - size_paged = (self.size // self.page_size + 1) * self.page_size - logger.warning(f"size {self.size} is not divisible by page_size {self.page_size}, will use paged_size {size_paged}") - self.size = size_paged - return \ No newline at end of file + return super().get_cell_size() + 132 + +class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 5b3fc1f13..c4e56c3c1 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -3,9 +3,8 @@ from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashInferStateInfo - +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class @@ -15,9 +14,13 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer # infer state class - infer_state_class = Deepseek3_2FlashInferStateInfo + infer_state_class = Deepseek3_2FlashAttentionStateInfo def _init_mem_manager(self): + manager_class = Deepseek3_2MemoryManager + if "triton_fp8kv" in self.mode: + manager_class = Deepseek3_2FP8KVMemoryManager + # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 if get_env_start_args().mtp_mode == 
"deepseekv3_eagle": @@ -25,7 +28,7 @@ def _init_mem_manager(self): elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": added_mtp_layer_num += get_env_start_args().mtp_step - self.mem_manager = Deepseek3_2MemoryManager( + self.mem_manager = manager_class( self.max_total_token_num, dtype=self.data_type, head_num=1, diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py new file mode 100644 index 000000000..a098795fb --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -0,0 +1,137 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_indexer_ks( + k_fp8, + k_scale, + mem_index, + buffer_fp8, + buffer_scale, + stride_k_fp8_bs, + stride_k_fp8_h, + stride_k_fp8_d, + stride_k_scale_bs, + stride_k_scale_h, + stride_k_scale_d, + stride_buffer_fp8_bs, + stride_buffer_fp8_h, + stride_buffer_fp8_d, + stride_buffer_scale_bs, + stride_buffer_scale_h, + stride_buffer_scale_d, + head_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(mem_index + cur_index).to(tl.int64) + + # Load k_fp8 data + k_fp8_ptrs = k_fp8 + cur_index * stride_k_fp8_bs + stride_k_fp8_h * offs_h[:, None] + stride_k_fp8_d * offs_d[None, :] + k_fp8_data = tl.load(k_fp8_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + + # Load k_scale data + k_scale_ptrs = k_scale + cur_index * stride_k_scale_bs + stride_k_scale_h * offs_h[:, None] + stride_k_scale_d * tl.arange(0, 1)[None, :] + k_scale_data = tl.load(k_scale_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + + # Store k_fp8 to buffer_fp8 + buffer_fp8_ptrs = buffer_fp8 + dest_index * stride_buffer_fp8_bs + stride_buffer_fp8_h * offs_h[:, None] + stride_buffer_fp8_d * offs_d[None, :] + tl.store(buffer_fp8_ptrs, k_fp8_data, mask=offs_h[:, None] < head_num) + + # Store k_scale to buffer_scale + buffer_scale_ptrs = buffer_scale + dest_index * stride_buffer_scale_bs + stride_buffer_scale_h * offs_h[:, None] + stride_buffer_scale_d * tl.arange(0, 1)[None, :] + tl.store(buffer_scale_ptrs, k_scale_data, mask=offs_h[:, None] < head_num) + + +@torch.no_grad() +def destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer): + seq_len = mem_index.shape[0] + head_num = k_fp8.shape[1] + k_fp8_dim = k_fp8.shape[2] # Should be 128 for float8 + k_scale_dim = k_scale.shape[2] # Should be 1 + + assert k_fp8.shape[1] == k_scale.shape[1] + assert k_fp8_dim == 128, f"k_fp8 dim should be 128, got {k_fp8_dim}" + assert k_scale_dim == 1, f"k_scale dim should be 1, got {k_scale_dim}" + assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" # 128 + 4 bytes + + # Reinterpret buffer as the appropriate types for storing + buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] + + BLOCK_HEAD = triton.next_power_of_2(head_num) + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_indexer_ks[grid]( + k_fp8, + k_scale, + mem_index, + buffer_fp8, + buffer_scale, + k_fp8.stride(0), + k_fp8.stride(1), + k_fp8.stride(2), + k_scale.stride(0), + k_scale.stride(1), + k_scale.stride(2), + buffer_fp8.stride(0), + buffer_fp8.stride(1), + buffer_fp8.stride(2), + buffer_scale.stride(0), + buffer_scale.stride(1), + buffer_scale.stride(2), + head_num, + 
BLOCK_DMODEL=k_fp8_dim, + BLOCK_HEAD=BLOCK_HEAD, + num_warps=num_warps, + num_stages=1, + ) + return + + +def test(): + import torch.nn.functional as F + + # Test parameters similar to the usage in nsa_indexer_layer_inder.py + B, N_CTX, H, K_DIM = 4, 1024, 8, 128 # batch_size, seq_len, heads, k_dim + seq_len = 50 # number of tokens to copy + dtype_fp8 = torch.float8_e4m3fn + dtype_scale = torch.float32 + + # Create test data + k_fp8 = torch.randn((seq_len, H, K_DIM), dtype=dtype_fp8).cuda() + k_scale = torch.randn((seq_len, H, 1), dtype=dtype_scale).cuda() + mem_index = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() + + # Create buffer [total_tokens, heads, 132] + buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() + + # Call the function + destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer) + + # Verify results + for i in range(seq_len): + dest_idx = mem_index[i].item() + # Check k_fp8 part + stored_fp8 = buffer[dest_idx, :, :128].view(dtype_fp8) + expected_fp8 = k_fp8[i] + assert torch.allclose(stored_fp8, expected_fp8, atol=1e-6), f"FP8 mismatch at index {i}" + + # Check k_scale part + stored_scale = buffer[dest_idx, :, 128:].view(dtype_scale)[:, :1] + expected_scale = k_scale[i] + assert torch.allclose(stored_scale, expected_scale, atol=1e-6), f"Scale mismatch at index {i}" + + print("All tests passed!") + + +if __name__ == "__main__": + test() diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py new file mode 100644 index 000000000..e69de29bb From f7773988802d7347a05f43a1d17669c4d3e73ada Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 7 Nov 2025 09:16:40 +0000 Subject: [PATCH 04/23] fix --- lightllm/models/deepseek3_2/infer_struct.py | 2 + .../layer_infer/nsa_indexer_layer_inder.py | 17 +++---- .../layer_infer/transformer_layer_infer.py | 50 ++++++++----------- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index bfdb53fd6..4d77b5f6f 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -9,6 +9,8 @@ def __init__(self): self.page_table_size_1 = None self.ks = None self.ke = None + + self.topk_indices = None return def init_some_extra_state(self, model, input_ids: torch.Tensor): diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 1977c211e..3e5e1c266 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -59,17 +59,16 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T cost = mask.sum() return logits, cost - def get_indices(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: - assert self.hidden_states is not None - assert self.q_lora is not None + def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: - q, k = self._get_q_k_bf16(infer_state, layer_weight) + q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) self._copy_ks_to_mem_cache(k_fp8, k_scale, 
infer_state.mem_index, infer_state.mem_manager) - weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale + weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] @@ -150,11 +149,11 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: ) == 0, "Hidden size must be a power of 2 for Hadamard transform." return hadamard_transform(x, scale=hidden_size**-0.5) - def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): - q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) - self.q_lora = None + def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): + q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) - k = layer_weight.wk_proj_.mm(self.hidden_states) + k = layer_weight.wk_proj_.mm(hidden_states) k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 076d3965c..01514e96a 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,5 +1,6 @@ from functools import partial from typing import override +from venv import logger import torch from sgl_kernel.flash_mla import flash_mla_sparse_fwd @@ -24,6 +25,7 @@ def __init__(self, layer_num, network_config, mode=[]): network_config=self.network_config_, mode=mode ) + self.topk_indices = None return @override @@ -35,20 +37,15 @@ def _get_qkv( ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - if self.q_lora_rank is None: - q = layer_weight.q_weight_.mm(input) - cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) - else: - q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 - ) - q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - self.indexer.hidden_states = input - self.indexer.q_lora = q + self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) - q = layer_weight.q_b_proj_.mm(q) - cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) rmsnorm_forward( @@ -69,11 +66,11 @@ def _get_qkv( @override def _bind_attention(self): super()._bind_attention() - self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) - self._token_attention_kernel = 
partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) + self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) pass - def _context_attention_flashmla_kernel_with_indexer( + def _nsa_context_attention_kernel( self, q: torch.Tensor, kv, @@ -85,21 +82,17 @@ def _context_attention_flashmla_kernel_with_indexer( q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - topk_indices = self.indexer.get_indices( - infer_state, - layer_weight.indexer_layer_weight, - ).unsqueeze(1) mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=topk_indices, + indices=self.topk_indices, sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) return mla_out - def _token_attention_flashmla_kernel_with_indexer( + def _nsa_token_attention_kernel( self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] @@ -107,24 +100,23 @@ def _token_attention_flashmla_kernel_with_indexer( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - topk_indices = self.indexer.get_indices( - infer_state, - layer_weight.indexer_layer_weight, - ) - o = flash_attn_with_kvcache( + k_descale, v_descale = None, None + o_tensor = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope, v_cache=kv_nope, qv=q_nope, - page_table=topk_indices, + page_table=self.topk_indices, cache_seqlens=infer_state.b_att_seq_len, cu_seqlens_q=infer_state.cu_seqlens_q, cu_seqlens_k_new=infer_state.cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, + window_size=(-1, -1), softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, return_softmax_lse=False, - num_splits=0, # TODO enable_deterministic_inference ) - return o \ No newline at end of file + return o_tensor \ No newline at end of file From 3ce1629a34e67970fdcaa86568e7b5be4f211dec Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 7 Nov 2025 10:09:56 +0000 Subject: [PATCH 05/23] need fix --- lightllm/models/deepseek3_2/__init__.py | 0 lightllm/models/deepseek3_2/infer_struct.py | 10 ++++++++-- .../layer_infer/transformer_layer_infer.py | 12 +++--------- lightllm/models/deepseek3_2/model.py | 5 +++++ 4 files changed, 16 insertions(+), 11 deletions(-) create mode 100644 lightllm/models/deepseek3_2/__init__.py diff --git a/lightllm/models/deepseek3_2/__init__.py b/lightllm/models/deepseek3_2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 4d77b5f6f..b1e61413c 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -9,8 +9,8 @@ def __init__(self): self.page_table_size_1 = None self.ks = None self.ke = None - - self.topk_indices = None + self.nsa_cu_seqlens_k = None + self.index_topk = 2048 return def init_some_extra_state(self, model, input_ids: torch.Tensor): @@ -24,3 +24,9 @@ 
def init_some_extra_state(self, model, input_ids: torch.Tensor): # since b_q_seq_len represents the new tokens being processed if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len + + self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk) + assert self.nsa_cache_seqlens.dtype == torch.int32 + self.nsa_cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) + ) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 01514e96a..188ab8b4a 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -86,7 +86,7 @@ def _nsa_context_attention_kernel( mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=self.topk_indices, + indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) @@ -100,23 +100,17 @@ def _nsa_token_attention_kernel( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - k_descale, v_descale = None, None o_tensor = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope, v_cache=kv_nope, qv=q_nope, page_table=self.topk_indices, - cache_seqlens=infer_state.b_att_seq_len, + cache_seqlens=infer_state.nsa_cache_seqlens, cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, ) return o_tensor \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index c4e56c3c1..ad7f70550 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -16,6 +16,11 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # infer state class infer_state_class = Deepseek3_2FlashAttentionStateInfo + def __init__(self, kvargs): + super().__init__(kvargs) + self.index_topk = self.config["index_topk"] + return + def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager if "triton_fp8kv" in self.mode: From 6ac528f5c5f4b8e01347602b4f693bf7490f8a88 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:07:54 +0000 Subject: [PATCH 06/23] run like deepseek v3 --- lightllm/models/deepseek3_2/infer_struct.py | 43 +++++- .../layer_infer/nsa_indexer_layer_inder.py | 104 ++++--------- .../layer_infer/transformer_layer_infer.py | 22 +-- lightllm/models/deepseek3_2/model.py | 5 +- .../triton_kernel/fp8_mqa_logits.py | 139 ++++++++++++++++++ 5 files changed, 224 insertions(+), 89 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index b1e61413c..8e5eb0b81 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,5 +1,6 @@ import torch from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager class 
Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): @@ -15,6 +16,9 @@ def __init__(self): def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) + self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager + # Ensure b_ready_cache_len is set for both prefill and decode modes if self.is_prefill: # b_ready_cache_len is already set in basemodel.py for prefill @@ -24,9 +28,42 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): # since b_q_seq_len represents the new tokens being processed if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len - - self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk) + + self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk) assert self.nsa_cache_seqlens.dtype == torch.int32 self.nsa_cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) - ) \ No newline at end of file + ) + + # Pre-compute NSA indexer indexing structures + self._init_nsa_indexing_structures() + + def _init_nsa_indexing_structures(self): + """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" + mem_index_list = [] + ks_list = [] + ke_list = [] + lengths_list = [] + offset = 0 + num_seq_len = self.b_req_idx.shape[0] + self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda') + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i] + q_seq_len = self.b_q_seq_len[i] + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + mem_index_list.append(mem_index) + self.page_table_size_1[i, :seq_len] = mem_index + ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks_list.append(ks) + ke_list.append(ke) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + offset += seq_len + + self.mem_index = torch.cat(mem_index_list, dim=0) + # ks : [seq_len_q] 标志kv的起始位置 + # ke : [seq_len_q] 标志kv的结束位置 + self.ks = torch.cat(ks_list, dim=0) + self.ke = torch.cat(ke_list, dim=0) + self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 3e5e1c266..d7444e918 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -10,7 +10,9 @@ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks -# from lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): @@ -66,70 +68,37 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) - self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, 
infer_state.mem_manager) + destindex_copy_indexer_ks( + k_fp8.unsqueeze(1), + k_scale.unsqueeze(1), + infer_state.mem_index, + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] - - k_fp8_list = [] - k_scale_list = [] - ks_list = [] - ke_list = [] - offset = 0 - for i in range(infer_state.batch_size): - q_len = infer_state.b_q_seq_len[i] - cache_len = infer_state.b_ready_cache_len[i] - mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] - k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous() - k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous() - ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda") - ke = ks + torch.arange(q_len, dtype=torch.int32, device="cuda") + 1 - k_fp8_list.append(k_fp8) - k_scale_list.append(k_scale) - ks_list.append(ks) - ke_list.append(ke) - offset += q_len - - k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) - k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) - kv_fp8 = (k_fp8, k_scale) - ks = torch.cat(ks_list, dim=0) - ke = torch.cat(ke_list, dim=0) - - logits = deep_gemm.fp8_mqa_logits( - q_fp8, - kv_fp8, - weights.squeeze(-1), - ks, - ke, - clean_logits=False, - ) - - return self.get_topk(logits, infer_state) - - def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo): - topk_indices_list = [] - offset = 0 - - for i in range(infer_state.batch_size): - q_len = infer_state.b_q_seq_len[i] - cache_len = infer_state.b_ready_cache_len[i] - end_pos = q_len + cache_len - # Slice logits for this batch (both query and sequence dimensions) - batch_logits = logits[offset:offset + q_len, :end_pos] - topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1] - mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] - indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda") - for j in range(q_len): - indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]] - topk_indices_list.append(indices) - offset += q_len + # Use pre-computed indexing structures from infer_state + mem_index = infer_state.mem_index + ks = infer_state.ks + ke = infer_state.ke + lengths = infer_state.lengths + page_table_1 = infer_state.page_table_size_1 - topk_indices_ = torch.cat(topk_indices_list, dim=0) + # TODO + k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous() + k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous() - return topk_indices_ + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + + # 返回 : [seq_q_len, topk] 无效的位置使用-1填充 + return fast_topk_transform_fused( + score=logits, # [seq_len_q, seq_len_kv] + lengths=lengths, # [seq_len_q] + page_table_size_1=page_table_1, # [seq_len_q, max(lengths)] 无效的使用0填充 + cu_seqlens_q=infer_state.cu_seqlens_q, # [seq_len_q + 1] + topk=self.index_topk, + ) def get_k_float32_from_buffer(self, buffer: torch.Tensor): @@ -152,8 +121,9 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, 
infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) - k = layer_weight.wk_proj_.mm(hidden_states) + + # TODO k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) @@ -168,17 +138,3 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = self._rotate_activation(q) k = self._rotate_activation(k) return q, k - - def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager): - # k_fp8 : [seq_len, 128] torch.fp8_e4m3 - # k_scale : [seq_len, 1] torch.float32 - # mem_index : [seq_len] torch.int32 - # buffer : [10000000, 1, 132] torch.uint8 - buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] - destindex_copy_indexer_ks( - k_fp8.unsqueeze(1), # Add head dimension: [seq_len, 1, 128] - k_scale.unsqueeze(1), # Add head dimension: [seq_len, 1, 1] - mem_index, - buffer - ) - return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 188ab8b4a..ed351312f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -82,10 +82,9 @@ def _nsa_context_attention_kernel( q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], + q=q_all, # [seq_len_q, q_num_head, qk_dim] + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], # [size, 1, qk_dim] indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, @@ -100,15 +99,16 @@ def _nsa_token_attention_kernel( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=self.topk_indices, - cache_seqlens=infer_state.nsa_cache_seqlens, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, + q=q_rope, # (q_seqlen, nheads, qk_headdim) + k_cache=k_rope, # (kv_size, 1, 1, qk_head_dim) + v_cache=kv_nope, # (kv_size, 1, 1, kv_lora_rank) + qv=q_nope, # (q_seqlen, nheads, kv_lora_rank) + page_table=self.topk_indices, # (q_seqlen, max_seq_len) + cache_seqlens=infer_state.nsa_cache_seqlens, # (q_seqlen) # 表示当前kv长度,用于读取page_table. 
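+            # nsa_cache_seqlens is b_att_seq_len clamped to index_topk: the per-query kv length
+            # that bounds how many page_table (topk_indices) entries are read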
+ cu_seqlens_q=infer_state.cu_seqlens_q, # (batch_size+1) [0,1] + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, #(batch_size+1) [0,9] max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index ad7f70550..b80094488 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,7 +5,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager -@ModelRegistry(["deepseek_v32"]) +# @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight @@ -21,6 +21,9 @@ def __init__(self, kvargs): self.index_topk = self.config["index_topk"] return + def _init_inferstate_cls(self): + self.infer_state_class = Deepseek3_2FlashAttentionStateInfo + def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager if "triton_fp8kv" in self.mode: diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py index e69de29bb..2fc92662a 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -0,0 +1,139 @@ +import triton +import triton.language as tl +import torch + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + Q_ptr, KV_ptr, KVScale_ptr, Weights_ptr, MemIndex_ptr, + CuSeqlenKs_ptr, CuSeqlenKe_ptr, Output_ptr, + seq_len, seq_len_kv, num_heads, head_dim, + stride_q_seq, stride_q_head, stride_q_dim, + stride_kv_pool, stride_kv_dim, + stride_w_seq, stride_w_head, + stride_o_seq, stride_o_kv, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Compute the range of seq positions this block handles + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + # Offset arrays for this block + offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) + + # Initialize accumulator for logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Create masks + mask_m = offs_m < seq_len + mask_n = offs_n < seq_len_kv + + # Load mem_indices for the KV positions + mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) + + # Load scales for K + scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) + + # Loop over all heads + for h in range(num_heads): + # Load weights for this head + weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, + mask=mask_m, other=0.0) + + # Initialize score accumulator for this head + score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Loop over head_dim in blocks + for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): + d_start = d_block * BLOCK_SIZE_D + offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) + mask_d = offs_d < head_dim + + # Load Q for this head and dimension block + # Q shape: (seq_len, num_heads, head_dim) + q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim + mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] + q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) + + # Load K for this 
dimension block + # KV shape: (pool_size, head_dim) as FP8 data + k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim + mask_k = mask_n[:, None] & mask_d[None, :] + k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) + + # Apply scale to K (scale is per-row of K) + k = k * scales[:, None] + + # Compute partial dot product: q @ k.T + # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) + # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) + score += tl.dot(q, tl.trans(k)) + + # Apply ReLU to score + score = tl.maximum(score, 0.0) + + # Multiply by weights and accumulate to logits + logits += score * weights[:, None] + + # Apply mask based on cu_seqlen_ks and cu_seqlen_ke + mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) + mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) + + mask_lo = offs_n[None, :] >= mask_ks[:, None] + mask_hi = offs_n[None, :] < mask_ke[:, None] + mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] + + # Apply mask (-inf for masked positions) + logits = tl.where(mask_valid, logits, float('-inf')) + + # Store output + out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv + mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) + tl.store(out_ptrs, logits, mask=mask_out) + + +def fp8_paged_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor = None +) -> torch.Tensor: + seq_len, num_heads, head_dim = q.shape + seq_len_kv = mem_index.shape[0] + + if out is None: + output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) + else: + output = out + + BLOCK_SIZE_M = 16 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_D = 128 + + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) + + _fp8_paged_mqa_logits_kernel[grid]( + q, kv, kv_scale, weights, mem_index, + cu_seqlen_ks, cu_seqlen_ke, output, + seq_len, seq_len_kv, num_heads, head_dim, + q.stride(0), q.stride(1), q.stride(2), + kv.stride(0), kv.stride(1), + weights.stride(0), weights.stride(1), + output.stride(0), output.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + + return output \ No newline at end of file From 976613475d6f0abbe5f2e11117ecbecee8c7cbf5 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:41:25 +0000 Subject: [PATCH 07/23] fix --- .../layer_infer/nsa_indexer_layer_inder.py | 12 +- .../triton_kernel/extract_indexer_ks.py | 156 ++++++++++++++++++ 2 files changed, 160 insertions(+), 8 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index d7444e918..173196bf4 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -10,6 +10,8 @@ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks +from 
lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -78,16 +80,13 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - # Use pre-computed indexing structures from infer_state mem_index = infer_state.mem_index ks = infer_state.ks ke = infer_state.ke lengths = infer_state.lengths page_table_1 = infer_state.page_table_size_1 - # TODO - k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous() - k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous() + k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index) logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) @@ -123,10 +122,7 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - # TODO - k = F.layer_norm( - k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps - ).type_as(k) + k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps) rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py new file mode 100644 index 000000000..e97454ba2 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -0,0 +1,156 @@ +import torch +import triton +import triton.language as tl +import numpy + + +@triton.jit +def _fwd_kernel_extract_indexer_ks( + buffer_fp8, + buffer_scale, + mem_index, + k_fp8_out, + k_scale_out, + stride_buffer_fp8_bs, + stride_buffer_fp8_h, + stride_buffer_fp8_d, + stride_buffer_scale_bs, + stride_buffer_scale_h, + stride_buffer_scale_d, + stride_k_fp8_out_bs, + stride_k_fp8_out_d, + stride_k_scale_out_bs, + BLOCK_DMODEL: tl.constexpr, +): + cur_index = tl.program_id(0) + + # Load the memory index + mem_idx = tl.load(mem_index + cur_index).to(tl.int64) + + # Load k_fp8 data from buffer_fp8[mem_idx, 0, :] + offs_d = tl.arange(0, BLOCK_DMODEL) + k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d + k_fp8_data = tl.load(k_fp8_ptrs) + + # Load k_scale data from buffer_scale[mem_idx, 0, 0] + k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d + k_scale_data = tl.load(k_scale_ptr) + + # Store k_fp8 output + k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d + tl.store(k_fp8_out_ptrs, k_fp8_data) + + # Store k_scale output + k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs + tl.store(k_scale_out_ptr, k_scale_data) + + +@torch.no_grad() +def extract_indexer_ks(buffer, mem_index): + """ + Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel. 
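+
+    Each buffer entry packs one token: bytes 0-127 hold the FP8 key and bytes 128-131 hold
+    its float32 scale, matching the layout written by destindex_copy_indexer_ks.
+
+    Example (illustrative sketch, assuming the buffer was filled by destindex_copy_indexer_ks):
+        >>> buffer = torch.zeros((1024, 1, 132), dtype=torch.uint8, device="cuda")
+        >>> mem_index = torch.arange(50, dtype=torch.int32, device="cuda")
+        >>> k_fp8, k_scale = extract_indexer_ks(buffer, mem_index)
+        >>> k_fp8.shape, k_fp8.dtype
+        (torch.Size([50, 128]), torch.float8_e4m3fn)
+        >>> k_scale.shape, k_scale.dtype
+        (torch.Size([50]), torch.float32)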
+ + Args: + buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8 + mem_index: Indices tensor of shape [seq_len] with dtype int32/int64 + + Returns: + k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn + k_scale: Tensor of shape [seq_len] with dtype float32 + """ + seq_len = mem_index.shape[0] + assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" + + # Reinterpret buffer as the appropriate types for Triton + buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] + + # Prepare output tensors + k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device) + k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device) + + BLOCK_DMODEL = 128 + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_extract_indexer_ks[grid]( + buffer_fp8, + buffer_scale, + mem_index, + k_fp8_out, + k_scale_out, + buffer_fp8.stride(0), + buffer_fp8.stride(1), + buffer_fp8.stride(2), + buffer_scale.stride(0), + buffer_scale.stride(1), + buffer_scale.stride(2), + k_fp8_out.stride(0), + k_fp8_out.stride(1), + k_scale_out.stride(0), + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + + return k_fp8_out, k_scale_out + + +def test(): + # Test parameters similar to the usage in nsa_indexer_layer_inder.py + B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this) + seq_len = 50 # number of tokens to extract + dtype_fp8 = torch.float8_e4m3fn + dtype_scale = torch.float32 + + # Create test buffer [total_tokens, heads, 132] as uint8 + buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() + + # Fill with test data - simulate what destindex_copy_indexer_ks does + test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() + # Generate fp8 data by converting from float32 + test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda() + test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8) + test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda() + + # Manually populate buffer as destindex_copy_indexer_ks would + for i in range(seq_len): + dest_idx = test_indices[i].item() + # Store fp8 data + buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8) + # Store scale data (4 bytes) - need to convert float32 to bytes + scale_bytes = test_k_scale[i].cpu().numpy().tobytes() + scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8) + buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device) + + # Call our extraction function + extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices) + + # Verify results + print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}") + print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: {extracted_fp8.dtype}") + print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}") + print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}") + + # Check if extraction matches (convert fp8 to float32 for comparison) + # Use higher tolerance for fp8 due to quantization precision + fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1) + scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6) + + print(f"FP8 data matches: {fp8_match}") + print(f"Scale data matches: {scale_match}") + + if fp8_match and scale_match: + print("All tests passed!") + else: + print("Test failed!") + if 
not fp8_match: + print("First few fp8 values:") + print(f"Original: {test_k_fp8_fp32[0, :5]}") + print(f"Extracted: {extracted_fp8.float()[0, :5]}") + if not scale_match: + print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}") + + +if __name__ == "__main__": + test() From 06f042902483b9255ab854afc4c699b6d8b4acb8 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:49:38 +0000 Subject: [PATCH 08/23] fix --- lightllm/models/deepseek3_2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index b80094488..8f1ba85cf 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,7 +5,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager -# @ModelRegistry(["deepseek_v32"]) +@ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight From 6b5a1b1a4323b815ef36ef98d3cec89b868f95e2 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:51:11 +0000 Subject: [PATCH 09/23] fix --- .../models/deepseek3_2/layer_infer/transformer_layer_infer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index ed351312f..5fc33d5aa 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,6 +1,5 @@ from functools import partial from typing import override -from venv import logger import torch from sgl_kernel.flash_mla import flash_mla_sparse_fwd From 9d06d87e99798af1fdcb97f1d0a0f7c9f1159525 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:53:59 +0000 Subject: [PATCH 10/23] fix --- lightllm/models/deepseek3_2/infer_struct.py | 2 -- .../layer_infer/nsa_indexer_layer_inder.py | 9 ++++----- .../layer_infer/transformer_layer_infer.py | 20 +++++++++---------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 8e5eb0b81..e955c3bbd 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -62,8 +62,6 @@ def _init_nsa_indexing_structures(self): offset += seq_len self.mem_index = torch.cat(mem_index_list, dim=0) - # ks : [seq_len_q] 标志kv的起始位置 - # ke : [seq_len_q] 标志kv的结束位置 self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 173196bf4..d5032e72f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -90,12 +90,11 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) - # 返回 : [seq_q_len, topk] 无效的位置使用-1填充 return fast_topk_transform_fused( - score=logits, # [seq_len_q, seq_len_kv] - 
lengths=lengths, # [seq_len_q] - page_table_size_1=page_table_1, # [seq_len_q, max(lengths)] 无效的使用0填充 - cu_seqlens_q=infer_state.cu_seqlens_q, # [seq_len_q + 1] + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.cu_seqlens_q, topk=self.index_topk, ) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 5fc33d5aa..5b550ab09 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -82,8 +82,8 @@ def _nsa_context_attention_kernel( q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, # [seq_len_q, q_num_head, qk_dim] - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], # [size, 1, qk_dim] + q=q_all, + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, @@ -100,14 +100,14 @@ def _nsa_token_attention_kernel( kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) o_tensor = flash_attn_with_kvcache( - q=q_rope, # (q_seqlen, nheads, qk_headdim) - k_cache=k_rope, # (kv_size, 1, 1, qk_head_dim) - v_cache=kv_nope, # (kv_size, 1, 1, kv_lora_rank) - qv=q_nope, # (q_seqlen, nheads, kv_lora_rank) - page_table=self.topk_indices, # (q_seqlen, max_seq_len) - cache_seqlens=infer_state.nsa_cache_seqlens, # (q_seqlen) # 表示当前kv长度,用于读取page_table. - cu_seqlens_q=infer_state.cu_seqlens_q, # (batch_size+1) [0,1] - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, #(batch_size+1) [0,9] + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=self.topk_indices, + cache_seqlens=infer_state.nsa_cache_seqlens, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, From ddb086dcd68479b6054815a787ab53fb7cadf60b Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 13:57:17 +0000 Subject: [PATCH 11/23] can run without cudagraph --- lightllm/models/deepseek3_2/infer_struct.py | 6 +- .../layer_infer/nsa_indexer_layer_inder.py | 24 +- .../layer_infer/transformer_layer_infer.py | 6 +- .../destindex_copy_indexer_ks.py | 354 ++++++++++----- .../triton_kernel/extract_indexer_ks.py | 409 ++++++++++++------ 5 files changed, 547 insertions(+), 252 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index e955c3bbd..c122c6a7e 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -40,7 +40,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): def _init_nsa_indexing_structures(self): """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" - mem_index_list = [] + req_all_mem_index_list = [] ks_list = [] ke_list = [] lengths_list = [] @@ -52,7 +52,7 @@ def _init_nsa_indexing_structures(self): seq_len = self.b_seq_len[i] q_seq_len = self.b_q_seq_len[i] mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - mem_index_list.append(mem_index) + req_all_mem_index_list.append(mem_index) self.page_table_size_1[i, :seq_len] = mem_index ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + 
offset + 1 @@ -61,7 +61,7 @@ def _init_nsa_indexing_structures(self): lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) offset += seq_len - self.mem_index = torch.cat(mem_index_list, dim=0) + self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index d5032e72f..df045dd2d 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -71,8 +71,8 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8.unsqueeze(1), - k_scale.unsqueeze(1), + k_fp8, + k_scale, infer_state.mem_index, infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] ) @@ -80,13 +80,16 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - mem_index = infer_state.mem_index ks = infer_state.ks ke = infer_state.ke lengths = infer_state.lengths page_table_1 = infer_state.page_table_size_1 - k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index) + # Use efficient Triton kernel to extract FP8 keys and scales from buffer + k_fp8_, k_scale_ = extract_indexer_ks( + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], + infer_state.req_all_mem_index + ) logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) @@ -99,12 +102,6 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, ) - def get_k_float32_from_buffer(self, buffer: torch.Tensor): - k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1] - k_float32 = k_fp8.float() * k_scale - return k_float32 - @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 @@ -121,8 +118,11 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps) - + # TODO + k = F.layer_norm( + k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps + ).type_as(k) + rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], k[:, None, : self.qk_rope_head_dim], diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 5b550ab09..df5220427 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -64,7 +64,11 @@ def _get_qkv( @override def _bind_attention(self): - super()._bind_attention() + if "triton_fp8kv" in self.mode: + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) + else: + self._copy_kv_to_mem_cache = 
partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) pass diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index a098795fb..46095bfb7 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -6,132 +6,270 @@ @triton.jit def _fwd_kernel_destindex_copy_indexer_ks( - k_fp8, - k_scale, - mem_index, - buffer_fp8, - buffer_scale, - stride_k_fp8_bs, - stride_k_fp8_h, - stride_k_fp8_d, - stride_k_scale_bs, - stride_k_scale_h, - stride_k_scale_d, - stride_buffer_fp8_bs, - stride_buffer_fp8_h, - stride_buffer_fp8_d, - stride_buffer_scale_bs, - stride_buffer_scale_h, - stride_buffer_scale_d, - head_num, + K_fp8, + K_scale, + DestLoc, + O_buffer, + stride_k_bs, + stride_k_d, + stride_scale_bs, + stride_scale_d, + stride_o_bs, + stride_o_h, + stride_o_d, BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, ): + """ + Triton kernel to copy FP8 K values and their scales to an indexed output buffer. + + This kernel reads FP8 key values (128 dims) and their float32 scale values, + then writes them to a compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The destination location for each source element is specified by DestLoc. + """ cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(mem_index + cur_index).to(tl.int64) - - # Load k_fp8 data - k_fp8_ptrs = k_fp8 + cur_index * stride_k_fp8_bs + stride_k_fp8_h * offs_h[:, None] + stride_k_fp8_d * offs_d[None, :] - k_fp8_data = tl.load(k_fp8_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - - # Load k_scale data - k_scale_ptrs = k_scale + cur_index * stride_k_scale_bs + stride_k_scale_h * offs_h[:, None] + stride_k_scale_d * tl.arange(0, 1)[None, :] - k_scale_data = tl.load(k_scale_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - - # Store k_fp8 to buffer_fp8 - buffer_fp8_ptrs = buffer_fp8 + dest_index * stride_buffer_fp8_bs + stride_buffer_fp8_h * offs_h[:, None] + stride_buffer_fp8_d * offs_d[None, :] - tl.store(buffer_fp8_ptrs, k_fp8_data, mask=offs_h[:, None] < head_num) - - # Store k_scale to buffer_scale - buffer_scale_ptrs = buffer_scale + dest_index * stride_buffer_scale_bs + stride_buffer_scale_h * offs_h[:, None] + stride_buffer_scale_d * tl.arange(0, 1)[None, :] - tl.store(buffer_scale_ptrs, k_scale_data, mask=offs_h[:, None] < head_num) + + # Load destination index for this thread + dest_index = tl.load(DestLoc + cur_index).to(tl.int64) + + # Load K_fp8 (128 values) and K_scale (1 value) from source + k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d + k_fp8 = tl.load(k_fp8_ptrs) + + k_scale = tl.load(K_scale + cur_index * stride_scale_bs) + + # Store K_fp8 to O_buffer[:, 0, :128] + # Convert fp8 to uint8 through bitcast for storage in uint8 buffer + o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d + k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True) + tl.store(o_k_ptrs, k_fp8_as_uint8) + + # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32) + # Convert float32 scale to 4 uint8 bytes 
using bitcast and bit manipulation + o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d + scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True) + + # Store each byte of the float32 scale (little-endian) + for i in range(4): + byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8) + tl.store(o_scale_ptr + i * stride_o_d, byte_val) + + return @torch.no_grad() -def destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer): - seq_len = mem_index.shape[0] - head_num = k_fp8.shape[1] - k_fp8_dim = k_fp8.shape[2] # Should be 128 for float8 - k_scale_dim = k_scale.shape[2] # Should be 1 +def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor): + """ + Copy FP8-quantized key values and their scales to indexed locations in a buffer. + + This function is used in the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) + mechanism to store compressed key representations in a memory buffer. Each key + is stored with its FP8 representation (128 bytes) followed by its float32 scale + (4 bytes), for a total of 132 bytes per key. + + Args: + K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn + FP8-quantized key values + K_scale: [q_seq_len, 1] torch.float32 + Quantization scales for each key + DestLoc: [q_seq_len] torch.int32 + Destination indices in the output buffer + O_buffer: [large_size, 1, 132] torch.uint8 + Output buffer where keys and scales will be written. + Must be a uint8 tensor to allow mixed-type storage. + Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales - assert k_fp8.shape[1] == k_scale.shape[1] - assert k_fp8_dim == 128, f"k_fp8 dim should be 128, got {k_fp8_dim}" - assert k_scale_dim == 1, f"k_scale dim should be 1, got {k_scale_dim}" - assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" # 128 + 4 bytes - - # Reinterpret buffer as the appropriate types for storing - buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] - - BLOCK_HEAD = triton.next_power_of_2(head_num) + Returns: + None (modifies O_buffer in-place) + + Example: + >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() + >>> k_scale = torch.randn(50, 1).cuda() + >>> dest_loc = torch.randint(0, 1024, (50,), dtype=torch.int32).cuda() + >>> o_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer) + >>> # Now o_buffer[dest_loc] contains the packed k_fp8 and k_scale data + """ + seq_len = DestLoc.shape[0] + head_dim = K_fp8.shape[1] + + assert head_dim == 128, f"Expected head_dim=128, got {head_dim}" + assert K_scale.shape[0] == seq_len + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" + grid = (seq_len,) num_warps = 1 - + _fwd_kernel_destindex_copy_indexer_ks[grid]( - k_fp8, - k_scale, - mem_index, - buffer_fp8, - buffer_scale, - k_fp8.stride(0), - k_fp8.stride(1), - k_fp8.stride(2), - k_scale.stride(0), - k_scale.stride(1), - k_scale.stride(2), - buffer_fp8.stride(0), - buffer_fp8.stride(1), - buffer_fp8.stride(2), - buffer_scale.stride(0), - buffer_scale.stride(1), - buffer_scale.stride(2), - head_num, - BLOCK_DMODEL=k_fp8_dim, - BLOCK_HEAD=BLOCK_HEAD, + K_fp8, + K_scale, + DestLoc, + O_buffer, + K_fp8.stride(0), + K_fp8.stride(1), + K_scale.stride(0), + K_scale.stride(1), + O_buffer.stride(0), + O_buffer.stride(1), + O_buffer.stride(2), + BLOCK_DMODEL=head_dim, 
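+        # each program instance copies one token entry (128 fp8 bytes + 4 scale bytes)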
num_warps=num_warps, num_stages=1, ) return -def test(): +def test_destindex_copy_indexer_ks(): + """Test the destindex_copy_indexer_ks kernel""" import torch.nn.functional as F - - # Test parameters similar to the usage in nsa_indexer_layer_inder.py - B, N_CTX, H, K_DIM = 4, 1024, 8, 128 # batch_size, seq_len, heads, k_dim - seq_len = 50 # number of tokens to copy - dtype_fp8 = torch.float8_e4m3fn - dtype_scale = torch.float32 - - # Create test data - k_fp8 = torch.randn((seq_len, H, K_DIM), dtype=dtype_fp8).cuda() - k_scale = torch.randn((seq_len, H, 1), dtype=dtype_scale).cuda() - mem_index = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() - - # Create buffer [total_tokens, heads, 132] - buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() - - # Call the function - destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer) - - # Verify results - for i in range(seq_len): - dest_idx = mem_index[i].item() - # Check k_fp8 part - stored_fp8 = buffer[dest_idx, :, :128].view(dtype_fp8) - expected_fp8 = k_fp8[i] - assert torch.allclose(stored_fp8, expected_fp8, atol=1e-6), f"FP8 mismatch at index {i}" - - # Check k_scale part - stored_scale = buffer[dest_idx, :, 128:].view(dtype_scale)[:, :1] - expected_scale = k_scale[i] - assert torch.allclose(stored_scale, expected_scale, atol=1e-6), f"Scale mismatch at index {i}" - - print("All tests passed!") + + print("=" * 80) + print("Testing destindex_copy_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random destination indices + dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(dest_loc) + + # Create input tensors + k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8 = (k_bf16 / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + # Create output buffer (as uint8 to allow reinterpretation) + o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + + # Run kernel + destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) + + # Extract results + k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) + + # Extract scale by reinterpreting 4 bytes as float32 + scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() + k_scale_out = scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) + + # Verify results at destination locations + k_fp8_extracted = k_fp8_out[dest_loc] + k_scale_extracted = k_scale_out[dest_loc] + + # Check FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8.to(torch.float32), + atol=0, rtol=0 + ) + + # Check scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_out = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" 
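+    # Exact equality (atol=0, rtol=0) is expected for the FP8 payload: the kernel
+    # stores the bytes bit-for-bit via a uint8 bitcast, so no requantization or
+    # rounding happens on the way through the 132-byte packed buffer.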
+ assert scale_match, "Scale values do not match!" + assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test edge cases + print("Testing edge cases...") + + # Test with sequential indices + dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) + + k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) + scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() + k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_seq = torch.allclose( + k_fp8_out_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_out_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Edge case tests passed!") + print() + + # Test with single element + print("Testing single element...") + dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) + + k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) + scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() + k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_single = torch.allclose( + k_fp8_out_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_out_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! 
✓") + print("=" * 80) if __name__ == "__main__": - test() + test_destindex_copy_indexer_ks() \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index e97454ba2..eb22fbb8f 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -1,156 +1,309 @@ import torch + import triton import triton.language as tl -import numpy @triton.jit def _fwd_kernel_extract_indexer_ks( - buffer_fp8, - buffer_scale, - mem_index, - k_fp8_out, - k_scale_out, - stride_buffer_fp8_bs, - stride_buffer_fp8_h, - stride_buffer_fp8_d, - stride_buffer_scale_bs, - stride_buffer_scale_h, - stride_buffer_scale_d, - stride_k_fp8_out_bs, - stride_k_fp8_out_d, - stride_k_scale_out_bs, + I_buffer, # Input buffer [large_size, 1, 132] uint8 + SrcLoc, # Source indices [req_size] int32/int64 + O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn + O_scale, # Output scale [req_size] float32 + stride_i_bs, + stride_i_h, + stride_i_d, + stride_o_fp8_bs, + stride_o_fp8_d, + stride_o_scale_bs, BLOCK_DMODEL: tl.constexpr, ): + """ + Triton kernel to extract FP8 K values and their scales from an indexed buffer. + + This kernel is the inverse of destindex_copy_indexer_ks. It reads from a + compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The source location for each output element is specified by SrcLoc. + """ cur_index = tl.program_id(0) - - # Load the memory index - mem_idx = tl.load(mem_index + cur_index).to(tl.int64) - - # Load k_fp8 data from buffer_fp8[mem_idx, 0, :] offs_d = tl.arange(0, BLOCK_DMODEL) - k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d - k_fp8_data = tl.load(k_fp8_ptrs) - - # Load k_scale data from buffer_scale[mem_idx, 0, 0] - k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d - k_scale_data = tl.load(k_scale_ptr) - - # Store k_fp8 output - k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d - tl.store(k_fp8_out_ptrs, k_fp8_data) - - # Store k_scale output - k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs - tl.store(k_scale_out_ptr, k_scale_data) + + # Load source index for this thread + src_index = tl.load(SrcLoc + cur_index).to(tl.int64) + + # Load K_fp8 from I_buffer[:, 0, :128] + i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d + k_fp8_as_uint8 = tl.load(i_k_ptrs) + + # Convert uint8 to fp8 through bitcast + k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) + + # Store K_fp8 to output + o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d + tl.store(o_k_ptrs, k_fp8) + + # Load K_scale from I_buffer[:, 0, 128:132] (4 bytes for float32) + # Load 4 bytes and reconstruct float32 (little-endian) + i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d + + # Load 4 bytes individually and combine them into uint32 + byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) + byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) + byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) + byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) + + # Combine bytes into uint32 (little-endian: byte0 is LSB) + scale_as_uint32 = byte0 | (byte1 
<< 8) | (byte2 << 16) | (byte3 << 24) + + # Bitcast uint32 to float32 + k_scale = scale_as_uint32.to(tl.float32, bitcast=True) + + # Store scale to output + o_scale_ptr = O_scale + cur_index * stride_o_scale_bs + tl.store(o_scale_ptr, k_scale) + + return @torch.no_grad() -def extract_indexer_ks(buffer, mem_index): +def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ - Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel. - + Extract FP8-quantized key values and their scales from indexed locations in a buffer. + + This function is the inverse operation of destindex_copy_indexer_ks. It's used in + the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) mechanism to retrieve + compressed key representations from a memory buffer. + Args: - buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8 - mem_index: Indices tensor of shape [seq_len] with dtype int32/int64 - + I_buffer: [large_size, 1, 132] torch.uint8 + Input buffer containing packed FP8 keys and float32 scales. + Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales + SrcLoc: [req_size] torch.int32 or torch.int64 + Source indices to extract from the input buffer + Returns: - k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn - k_scale: Tensor of shape [seq_len] with dtype float32 + tuple containing: + - K_fp8: [req_size, 128] torch.float8_e4m3fn + FP8-quantized key values + - K_scale: [req_size] torch.float32 + Quantization scales for each key + + Example: + >>> i_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> src_loc = torch.tensor([10, 20, 30], dtype=torch.int32).cuda() + >>> k_fp8, k_scale = extract_indexer_ks(i_buffer, src_loc) + >>> # k_fp8.shape == [3, 128], k_scale.shape == [3] """ - seq_len = mem_index.shape[0] - assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" - - # Reinterpret buffer as the appropriate types for Triton - buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] - - # Prepare output tensors - k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device) - k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device) - - BLOCK_DMODEL = 128 - grid = (seq_len,) + req_size = SrcLoc.shape[0] + head_dim = 128 + + assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" + assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" + + # Allocate output tensors + O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) + O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) + + grid = (req_size,) num_warps = 1 - + _fwd_kernel_extract_indexer_ks[grid]( - buffer_fp8, - buffer_scale, - mem_index, - k_fp8_out, - k_scale_out, - buffer_fp8.stride(0), - buffer_fp8.stride(1), - buffer_fp8.stride(2), - buffer_scale.stride(0), - buffer_scale.stride(1), - buffer_scale.stride(2), - k_fp8_out.stride(0), - k_fp8_out.stride(1), - k_scale_out.stride(0), - BLOCK_DMODEL=BLOCK_DMODEL, + I_buffer, + SrcLoc, + O_fp8, + O_scale, + I_buffer.stride(0), + I_buffer.stride(1), + I_buffer.stride(2), + O_fp8.stride(0), + O_fp8.stride(1), + O_scale.stride(0), + BLOCK_DMODEL=head_dim, num_warps=num_warps, num_stages=1, ) + + return O_fp8, O_scale - return k_fp8_out, k_scale_out - - -def test(): - # Test parameters similar to the 
usage in nsa_indexer_layer_inder.py - B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this) - seq_len = 50 # number of tokens to extract - dtype_fp8 = torch.float8_e4m3fn - dtype_scale = torch.float32 - # Create test buffer [total_tokens, heads, 132] as uint8 - buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() - - # Fill with test data - simulate what destindex_copy_indexer_ks does - test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() - # Generate fp8 data by converting from float32 - test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda() - test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8) - test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda() - - # Manually populate buffer as destindex_copy_indexer_ks would - for i in range(seq_len): - dest_idx = test_indices[i].item() - # Store fp8 data - buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8) - # Store scale data (4 bytes) - need to convert float32 to bytes - scale_bytes = test_k_scale[i].cpu().numpy().tobytes() - scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8) - buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device) - - # Call our extraction function - extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices) - - # Verify results - print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}") - print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: {extracted_fp8.dtype}") - print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}") - print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}") - - # Check if extraction matches (convert fp8 to float32 for comparison) - # Use higher tolerance for fp8 due to quantization precision - fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1) - scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6) - - print(f"FP8 data matches: {fp8_match}") - print(f"Scale data matches: {scale_match}") - - if fp8_match and scale_match: - print("All tests passed!") - else: - print("Test failed!") - if not fp8_match: - print("First few fp8 values:") - print(f"Original: {test_k_fp8_fp32[0, :5]}") - print(f"Extracted: {extracted_fp8.float()[0, :5]}") - if not scale_match: - print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}") +def test_extract_indexer_ks(): + """Test the extract_indexer_ks kernel against the copy kernel""" + import torch.nn.functional as F + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks + + print("=" * 80) + print("Testing extract_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random indices for writing + write_indices = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(write_indices) + + # Create input tensors + k_bf16_original = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16_original.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_original = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_original = (k_bf16_original / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + # Create buffer and write data using 
destindex_copy_indexer_ks + buffer = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_original, k_scale_original, write_indices, buffer) + + # Now extract the data back using extract_indexer_ks + k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, write_indices) + + # Verify FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8_original.to(torch.float32), + atol=0, rtol=0 + ) + + # Verify scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale_original.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16_original, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" + assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test with sequential indices + print("Testing sequential indices...") + write_indices_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, write_indices_seq, buffer_seq) + k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, write_indices_seq) + + fp8_match_seq = torch.allclose( + k_fp8_ext_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_ext_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Sequential test passed!") + print() + + # Test with single element + print("Testing single element...") + write_idx_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, write_idx_single, buffer_single) + k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, write_idx_single) + + fp8_match_single = torch.allclose( + k_fp8_ext_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_ext_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") + assert 
fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + # Test with larger batch to check performance characteristics + print("Testing larger batch (performance check)...") + write_indices_large = torch.randint(0, large_size * 10, (500,), device="cuda", dtype=torch.int32).unique() + actual_large_len = len(write_indices_large) + k_bf16_large = torch.randn((actual_large_len, head_dim), dtype=dtype, device="cuda") + k_abs_max_large = k_bf16_large.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_large = (k_abs_max_large / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_large = (k_bf16_large / k_abs_max_large).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_large = torch.zeros((large_size * 10, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_large, k_scale_large, write_indices_large, buffer_large) + + # Warm up + for _ in range(3): + _ = extract_indexer_ks(buffer_large, write_indices_large) + + # Time it + torch.cuda.synchronize() + import time + start = time.time() + for _ in range(100): + k_fp8_ext_large, k_scale_ext_large = extract_indexer_ks(buffer_large, write_indices_large) + torch.cuda.synchronize() + elapsed = time.time() - start + + fp8_match_large = torch.allclose( + k_fp8_ext_large.to(torch.float32), + k_fp8_large.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_large = torch.allclose( + k_scale_ext_large, + k_scale_large.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Large batch (size={actual_large_len}): FP8={fp8_match_large}, Scale={scale_match_large}") + print(f" Average time per call: {elapsed/100*1000:.3f} ms") + assert fp8_match_large and scale_match_large + print("✓ Large batch test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! 
✓") + print("=" * 80) if __name__ == "__main__": - test() + test_extract_indexer_ks() From 9bbffcd27c9a78e9cb2a6666e10d03cbcf205f07 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 14:32:31 +0000 Subject: [PATCH 12/23] fix cudagraph --- lightllm/models/deepseek3_2/infer_struct.py | 168 +++++++++++++++++--- 1 file changed, 147 insertions(+), 21 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index c122c6a7e..db6e61a1c 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,8 +1,10 @@ import torch +import weakref from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): + _shared_nsa_buffers = None def __init__(self): super().__init__() @@ -14,8 +16,42 @@ def __init__(self): self.index_topk = 2048 return + @classmethod + def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): + """Get or create pre-allocated buffers for CUDA graph execution""" + if cls._shared_nsa_buffers is None: + # Pre-allocate buffers for max possible sizes + max_total_q_tokens = graph_max_batch_size * max_seq_len + max_total_tokens = graph_max_batch_size * max_seq_len + + cls._shared_nsa_buffers = [ + { + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + }, + { # Second buffer for microbatch overlap if needed + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + } + ] + return cls._shared_nsa_buffers + def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) + + # Store weak reference to model for accessing graph parameters + self._model_ref = weakref.ref(model) + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager @@ -29,11 +65,34 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len - self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk) + # Check if we can use CUDA graph based on batch size and max_len constraints + use_cuda_graph_buffers = False + if (hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and 
+ self.batch_size <= model.graph_max_batch_size and + self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + # Setup nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph + if use_cuda_graph_buffers: + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.nsa_cache_seqlens = buffer['nsa_cache_seqlens'][:self.batch_size] + self.nsa_cu_seqlens_k = buffer['nsa_cu_seqlens_k'][:self.batch_size + 1] + else: + # Create new tensors dynamically + self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device='cuda') + self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device='cuda') + + # Calculate actual values + self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 - self.nsa_cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) - ) + + # Compute cumulative sum with padding + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) + self.nsa_cu_seqlens_k[0] = 0 # Pre-compute NSA indexer indexing structures self._init_nsa_indexing_structures() @@ -46,22 +105,89 @@ def _init_nsa_indexing_structures(self): lengths_list = [] offset = 0 num_seq_len = self.b_req_idx.shape[0] - self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda') + max_seq_len = self.b_seq_len.max().item() + + # Calculate total sizes needed + total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) + total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) + + # Check if we should use CUDA graph buffers + use_cuda_graph_buffers = False + if hasattr(self, '_model_ref'): + model = self._model_ref() + if (model is not None and + hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and + self.batch_size <= model.graph_max_batch_size and + self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + if use_cuda_graph_buffers: + # Use pre-allocated buffers for CUDA graph + model = self._model_ref() + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.ks = buffer['ks'][:total_q_len] + self.ke = buffer['ke'][:total_q_len] + self.lengths = buffer['lengths'][:total_q_len] + self.page_table_size_1 = buffer['page_table_size_1'][:num_seq_len, :max_seq_len] + self.req_all_mem_index = buffer['req_all_mem_index'][:total_seq_len] + + # Zero out page_table_size_1 before filling + self.page_table_size_1.zero_() + + # Compute and copy values into the pre-allocated buffer views + ks_offset = 0 + ke_offset = 0 + lengths_offset = 0 + req_offset = 0 + seq_offset = 0 + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + + # Copy req_all_mem_index + self.req_all_mem_index[req_offset:req_offset + seq_len] = mem_index + + # Fill page_table_size_1 + self.page_table_size_1[i, :seq_len] = mem_index + + # Fill ks, ke, lengths + self.ks[ks_offset:ks_offset + q_seq_len].fill_(seq_offset) + self.ke[ke_offset:ke_offset + q_seq_len] = torch.arange( + seq_offset + 1, seq_offset 
+ q_seq_len + 1, dtype=torch.int, device='cuda' + ) + self.lengths[lengths_offset:lengths_offset + q_seq_len] = torch.arange( + seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda' + ) + + ks_offset += q_seq_len + ke_offset += q_seq_len + lengths_offset += q_seq_len + req_offset += seq_len + seq_offset += seq_len + else: + # Original dynamic allocation for non-CUDA graph mode + self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device='cuda') - for i in range(num_seq_len): - seq_len = self.b_seq_len[i] - q_seq_len = self.b_q_seq_len[i] - mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - req_all_mem_index_list.append(mem_index) - self.page_table_size_1[i, :seq_len] = mem_index - ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset - ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 - ks_list.append(ks) - ke_list.append(ke) - lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) - offset += seq_len + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + req_all_mem_index_list.append(mem_index) + self.page_table_size_1[i, :seq_len] = mem_index + ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks_list.append(ks) + ke_list.append(ke) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + offset += seq_len - self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) - self.ks = torch.cat(ks_list, dim=0) - self.ke = torch.cat(ke_list, dim=0) - self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file + self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) + self.ks = torch.cat(ks_list, dim=0) + self.ke = torch.cat(ke_list, dim=0) + self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file From 47dfb5580c962a5145c55a46c7d5321881e9b1df Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 26 Dec 2025 10:53:38 +0000 Subject: [PATCH 13/23] can run --- lightllm/common/basemodel/basemodel.py | 20 ++ lightllm/common/infer_utils.py | 70 +++++- .../kv_cache_mem_manager/mem_manager.py | 15 +- lightllm/models/deepseek3_2/infer_struct.py | 144 ++++++----- .../layer_infer/nsa_indexer_layer_inder.py | 87 ++++--- .../layer_infer/transformer_layer_infer.py | 33 +-- .../triton_kernel/copy_indexer_ks.py | 232 ++++++++++++++++++ .../destindex_copy_indexer_ks.py | 151 +++++------- 8 files changed, 552 insertions(+), 200 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5c1d2b871..c11b68c99 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -476,6 +476,24 @@ def _prefill( ) infer_state = self._create_inferstate(model_input) + + # Capture old indexer_ks positions before they are overwritten + # This is needed for DeepSeek v3.2 to copy cached tokens' indexer_ks + old_indexer_ks_positions = [] + for i in range(infer_state.b_req_idx.shape[0]): + req_idx = infer_state.b_req_idx[i].item() + ready_cache_len = infer_state.b_ready_cache_len[i].item() + + if ready_cache_len > 0: + # Capture old positions for cached tokens + old_pos = self.req_manager.req_to_token_indexs[ + 
req_idx, 0:ready_cache_len + ].clone() # Clone to avoid view issues + old_indexer_ks_positions.append(old_pos) + else: + # No cached tokens for this request + old_indexer_ks_positions.append(None) + init_req_to_token_indexes( req_to_token_indexs=self.req_manager.req_to_token_indexs, b_req_idx=infer_state.b_req_idx, @@ -484,6 +502,8 @@ def _prefill( b_start_loc=model_input.b_prefill_start_loc, alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, + mem_manager=self.req_manager.mem_manager, + old_indexer_ks_positions=old_indexer_ks_positions, ) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index e1b9cc383..ed3c0b73e 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -2,8 +2,17 @@ def init_req_to_token_indexes( - req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, b_start_loc, alloc_mem_index, max_q_seq_len + req_to_token_indexs, + b_req_idx, + b_seq_len, + b_ready_cache_len, + b_start_loc, + alloc_mem_index, + max_q_seq_len, + mem_manager=None, + old_indexer_ks_positions=None, ): + # Step 1: Copy KV cache for NEW tokens (existing logic) copy_kv_index_to_req_prefill( req_to_token_indexs=req_to_token_indexs, b_req_idx=b_req_idx, @@ -13,3 +22,62 @@ def init_req_to_token_indexes( memindex=alloc_mem_index, max_q_seq_len=max_q_seq_len, ) + + # Step 2: Copy indexer_ks for CACHED tokens (DeepSeek v3.2 specific) + # This ensures consistency between KV cache and indexer_ks buffers + # when prefix cache is hit + if ( + mem_manager is not None + and hasattr(mem_manager, "indexer_ks_mem_manager") + and old_indexer_ks_positions is not None + ): + + _copy_cached_indexer_ks_to_new_positions( + req_to_token_indexs=req_to_token_indexs, + b_req_idx=b_req_idx, + b_ready_cache_len=b_ready_cache_len, + mem_manager=mem_manager, + old_indexer_ks_positions=old_indexer_ks_positions, + ) + + +def _copy_cached_indexer_ks_to_new_positions( + req_to_token_indexs, + b_req_idx, + b_ready_cache_len, + mem_manager, + old_indexer_ks_positions, +): + """ + Copy cached tokens' indexer_ks from old positions to new positions. + + This function is called after copy_kv_index_to_req_prefill() has updated + req_to_token_indexs to point to new contiguous positions. We need to copy + indexer_ks data to match the KV cache layout. 
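+
+    Illustrative example (slot numbers are made up): if a request reuses a 3-token
+    prefix whose indexer_ks entries live at buffer slots [100, 101, 102] and the new
+    prefill layout places those tokens at slots [7, 8, 9], this helper copies the
+    packed 132-byte entries from [100, 101, 102] to [7, 8, 9] for each layer, so a
+    later extract_indexer_ks() call sees the cached keys at the new locations.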
+ + For each layer and each request with cached tokens: + - Copy indexer_ks data from old positions to new positions + - This ensures consistency when using extract_indexer_ks later + """ + from lightllm.models.deepseek3_2.triton_kernel.copy_indexer_ks import copy_indexer_ks + + # Get number of layers from indexer_ks_mem_manager + num_layers = len(mem_manager.indexer_ks_mem_manager.kv_buffer) + indexer_buffer = mem_manager.indexer_ks_mem_manager.kv_buffer + + for layer_idx in range(num_layers): + for i in range(b_req_idx.shape[0]): + req_idx = b_req_idx[i].item() + ready_cache_len = b_ready_cache_len[i].item() + old_positions = old_indexer_ks_positions[i] + + if ready_cache_len > 0 and old_positions is not None: + # New positions after copy_kv_index_to_req_prefill + new_positions = req_to_token_indexs[req_idx, 0:ready_cache_len] + + # Copy indexer_ks: old_positions -> new_positions + copy_indexer_ks( + buffer=indexer_buffer[layer_idx], + src_loc=old_positions, + dest_loc=new_positions, + ) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 2940d74e2..7d5e2af04 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -26,7 +26,9 @@ class MemoryManager: - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + def __init__( + self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False + ): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -93,6 +95,17 @@ def profile_size(self, mem_fraction): available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) + + # Ensure size is at least a minimum positive value to avoid torch.arange errors + MIN_SIZE = 1024 # Minimum 1024 tokens + if self.size < MIN_SIZE: + logger.warning( + f"Insufficient memory for KV cache. Available: {available_memory:.2f} GB, " + f"but calculated size is {self.size} tokens. Using minimum size {MIN_SIZE} tokens instead. " + f"Consider reducing model size, using fewer GPUs, or increasing mem_fraction." 
+ ) + self.size = MIN_SIZE + if world_size > 1: tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") dist.all_reduce(tensor, op=dist.ReduceOp.MIN) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index db6e61a1c..2f8aa7562 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -3,6 +3,7 @@ from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager + class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): _shared_nsa_buffers = None @@ -23,35 +24,35 @@ def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): # Pre-allocate buffers for max possible sizes max_total_q_tokens = graph_max_batch_size * max_seq_len max_total_tokens = graph_max_batch_size * max_seq_len - + cls._shared_nsa_buffers = [ { - 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), - 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), - 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), - 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device="cuda"), + "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), + "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), + "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), }, { # Second buffer for microbatch overlap if needed - 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), - 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), - 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), - 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), - } + "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device="cuda"), + "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), + "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), + "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), + }, ] return cls._shared_nsa_buffers def init_some_extra_state(self, model, 
input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) - + # Store weak reference to model for accessing graph parameters self._model_ref = weakref.ref(model) - + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager @@ -60,36 +61,39 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): # b_ready_cache_len is already set in basemodel.py for prefill pass else: - # In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len - # since b_q_seq_len represents the new tokens being processed + # In decode mode, b_ready_cache_len is set by the router/scheduler + # based on actual prefix cache hits. If it's None (no prefix cache enabled), + # it should be 0, not computed from b_seq_len - b_q_seq_len if self.b_ready_cache_len is None: - self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len + self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) # Check if we can use CUDA graph based on batch size and max_len constraints use_cuda_graph_buffers = False - if (hasattr(model, 'graph_max_batch_size') and - hasattr(model, 'graph_max_len_in_batch') and - self.batch_size <= model.graph_max_batch_size and - self.max_len_in_batch <= model.graph_max_len_in_batch): + if ( + hasattr(model, "graph_max_batch_size") + and hasattr(model, "graph_max_len_in_batch") + and self.batch_size <= model.graph_max_batch_size + and self.max_len_in_batch <= model.graph_max_len_in_batch + ): use_cuda_graph_buffers = True - + # Setup nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph if use_cuda_graph_buffers: buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - + # Use views into pre-allocated buffers - self.nsa_cache_seqlens = buffer['nsa_cache_seqlens'][:self.batch_size] - self.nsa_cu_seqlens_k = buffer['nsa_cu_seqlens_k'][:self.batch_size + 1] + self.nsa_cache_seqlens = buffer["nsa_cache_seqlens"][: self.batch_size] + self.nsa_cu_seqlens_k = buffer["nsa_cu_seqlens_k"][: self.batch_size + 1] else: # Create new tensors dynamically - self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device='cuda') - self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device='cuda') - + self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") + self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") + # Calculate actual values self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 - + # Compute cumulative sum with padding torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) self.nsa_cu_seqlens_k[0] = 0 @@ -106,65 +110,68 @@ def _init_nsa_indexing_structures(self): offset = 0 num_seq_len = self.b_req_idx.shape[0] max_seq_len = self.b_seq_len.max().item() - + # Calculate total sizes needed total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) - + # Check if we should use CUDA graph buffers use_cuda_graph_buffers = False - if hasattr(self, '_model_ref'): + if hasattr(self, "_model_ref"): model = self._model_ref() - if (model is not None and - hasattr(model, 'graph_max_batch_size') and - hasattr(model, 'graph_max_len_in_batch') and - self.batch_size <= model.graph_max_batch_size and - 
self.max_len_in_batch <= model.graph_max_len_in_batch): + if ( + model is not None + and hasattr(model, "graph_max_batch_size") + and hasattr(model, "graph_max_len_in_batch") + and self.batch_size <= model.graph_max_batch_size + and self.max_len_in_batch <= model.graph_max_len_in_batch + ): use_cuda_graph_buffers = True - + if use_cuda_graph_buffers: # Use pre-allocated buffers for CUDA graph model = self._model_ref() buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - + # Use views into pre-allocated buffers - self.ks = buffer['ks'][:total_q_len] - self.ke = buffer['ke'][:total_q_len] - self.lengths = buffer['lengths'][:total_q_len] - self.page_table_size_1 = buffer['page_table_size_1'][:num_seq_len, :max_seq_len] - self.req_all_mem_index = buffer['req_all_mem_index'][:total_seq_len] - + self.ks = buffer["ks"][:total_q_len] + self.ke = buffer["ke"][:total_q_len] + self.lengths = buffer["lengths"][:total_q_len] + self.page_table_size_1 = buffer["page_table_size_1"][:num_seq_len, :max_seq_len] + self.req_all_mem_index = buffer["req_all_mem_index"][:total_seq_len] + # Zero out page_table_size_1 before filling self.page_table_size_1.zero_() - + # Compute and copy values into the pre-allocated buffer views ks_offset = 0 ke_offset = 0 lengths_offset = 0 req_offset = 0 seq_offset = 0 - + for i in range(num_seq_len): seq_len = self.b_seq_len[i].item() q_seq_len = self.b_q_seq_len[i].item() - mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - + req_idx = self.b_req_idx[i].item() + mem_index = self.req_manager.req_to_token_indexs[req_idx, :seq_len] + # Copy req_all_mem_index - self.req_all_mem_index[req_offset:req_offset + seq_len] = mem_index - + self.req_all_mem_index[req_offset : req_offset + seq_len] = mem_index + # Fill page_table_size_1 self.page_table_size_1[i, :seq_len] = mem_index - + # Fill ks, ke, lengths - self.ks[ks_offset:ks_offset + q_seq_len].fill_(seq_offset) - self.ke[ke_offset:ke_offset + q_seq_len] = torch.arange( - seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device='cuda' + self.ks[ks_offset : ks_offset + q_seq_len].fill_(seq_offset) + self.ke[ke_offset : ke_offset + q_seq_len] = torch.arange( + seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device="cuda" ) - self.lengths[lengths_offset:lengths_offset + q_seq_len] = torch.arange( - seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda' + self.lengths[lengths_offset : lengths_offset + q_seq_len] = torch.arange( + seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device="cuda" ) - + ks_offset += q_seq_len ke_offset += q_seq_len lengths_offset += q_seq_len @@ -172,22 +179,23 @@ def _init_nsa_indexing_structures(self): seq_offset += seq_len else: # Original dynamic allocation for non-CUDA graph mode - self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device='cuda') + self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device="cuda") for i in range(num_seq_len): seq_len = self.b_seq_len[i].item() q_seq_len = self.b_q_seq_len[i].item() - mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + req_idx = self.b_req_idx[i].item() + mem_index = self.req_manager.req_to_token_indexs[req_idx, :seq_len] req_all_mem_index_list.append(mem_index) self.page_table_size_1[i, :seq_len] = mem_index - ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset - ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 
1 + ks = torch.zeros(q_seq_len, dtype=torch.int, device="cuda") + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device="cuda") + offset + 1 ks_list.append(ks) ke_list.append(ke) - lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device="cuda")) offset += seq_len self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) - self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file + self.lengths = torch.cat(lengths_list, dim=0) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index df045dd2d..2f4421e74 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -16,6 +16,7 @@ logger = init_logger(__name__) + class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): super().__init__() @@ -38,13 +39,20 @@ def __init__(self, layer_idx, network_config, mode=[]): return - def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + def ref_fp8_mqa_logits( + self, + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + cost_only: bool = False, + ): seq_len_kv = kv.shape[0] if cost_only: start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) - end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) count_ones_per_row = (end - start).clamp(min=0) return count_ones_per_row.sum() @@ -52,29 +60,31 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T q = q.float() k = k.float() - mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] mask = mask_lo & mask_hi - score = torch.einsum('mhd,nd->hmn', q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float('-inf')) + logits = logits.masked_fill(~mask, float("-inf")) cost = mask.sum() return logits, cost - def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + def get_indices( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, + layer_weight: NSAIndexerWeight, + ) -> torch.Tensor: q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8, - k_scale, - infer_state.mem_index, - infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + k_fp8, k_scale, infer_state.mem_index, infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] ) weights = 
layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale @@ -87,34 +97,47 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( - infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], - infer_state.req_all_mem_index + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index ) - logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) - + # Get actual sequence length from q (which comes from q_lora) + # This may differ from ks.shape[0] during certain operations + actual_seq_len = q.shape[0] + + # ks, ke, lengths, and weights should all match actual_seq_len + # Slice them if they don't match + if ks.shape[0] != actual_seq_len: + ks = ks[:actual_seq_len] + ke = ke[:actual_seq_len] + lengths = lengths[:actual_seq_len] + weights = weights[:actual_seq_len] + + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + return fast_topk_transform_fused( - score=logits, - lengths=lengths, - page_table_size_1=page_table_1, - cu_seqlens_q=infer_state.cu_seqlens_q, + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.cu_seqlens_q, topk=self.index_topk, ) - @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 from sgl_kernel import hadamard_transform hidden_size = x.size(-1) - assert ( - hidden_size & (hidden_size - 1) - ) == 0, "Hidden size must be a power of 2 for Hadamard transform." - return hadamard_transform(x, scale=hidden_size**-0.5) - - def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): + assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." 
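+        # hadamard_transform applies a Walsh-Hadamard transform along the last dim;
+        # the hidden_size ** -0.5 scale keeps the transform orthonormal, so activation
+        # norms are preserved before the rotated q/k are FP8-quantized by act_quant.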
+ return hadamard_transform(x, scale=hidden_size ** -0.5) + + def _get_q_k_bf16( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, + layer_weight: NSAIndexerWeight, + ): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) @@ -123,11 +146,13 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) + # Slice position_cos and position_sin to match actual token length + actual_seq_len = q.shape[0] rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], k[:, None, : self.qk_rope_head_dim], - infer_state.position_cos, - infer_state.position_sin, + infer_state.position_cos[:actual_seq_len], + infer_state.position_sin[:actual_seq_len], ) q = self._rotate_activation(q) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index df5220427..cf748bcdb 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -19,11 +19,7 @@ def __init__(self, layer_num, network_config, mode=[]): self.index_topk = network_config["index_topk"] super().__init__(layer_num, network_config, mode) - self.indexer = NSAIndexerInfer( - layer_idx=self.layer_num_, - network_config=self.network_config_, - mode=mode - ) + self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_, mode=mode) self.topk_indices = None return @@ -41,6 +37,9 @@ def _get_qkv( ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + # Process all tokens for indexer + # Note: Prefix cache slicing optimization is disabled due to batch structure + # mismatch issues with fast_topk_transform_fused kernel self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) q = layer_weight.q_b_proj_.mm(q) @@ -81,12 +80,12 @@ def _nsa_context_attention_kernel( layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: - + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, + q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, @@ -95,7 +94,11 @@ def _nsa_context_attention_kernel( return mla_out def _nsa_token_attention_kernel( - self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + self, + q, + infer_state: Deepseek3_2FlashAttentionStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) @@ -104,16 +107,16 @@ def _nsa_token_attention_kernel( kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, + q=q_rope, + k_cache=k_rope, v_cache=kv_nope, - qv=q_nope, - page_table=self.topk_indices, - cache_seqlens=infer_state.nsa_cache_seqlens, + qv=q_nope, + page_table=self.topk_indices, + 
cache_seqlens=infer_state.nsa_cache_seqlens, cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, ) - return o_tensor \ No newline at end of file + return o_tensor diff --git a/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py new file mode 100644 index 000000000..93cf463eb --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py @@ -0,0 +1,232 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_copy_indexer_ks( + buffer, # [large_size, 1, 132] uint8 + src_loc, # [copy_len] int32/int64 - source indices + dest_loc, # [copy_len] int32/int64 - destination indices + stride_bs, + stride_h, + stride_d, + BLOCK_KV: tl.constexpr, # = 128 (FP8 data) + BLOCK_SCALE: tl.constexpr, # = 4 (scale data) +): + """ + Triton kernel to copy indexer_ks data from source locations to destination locations. + + This kernel copies 132-byte indexer_ks entries (128 bytes FP8 key + 4 bytes float32 scale) + from source positions to destination positions within the same buffer. + + Args: + buffer: Shared buffer containing indexer_ks data [large_size, 1, 132] uint8 + src_loc: Source indices to copy from [copy_len] + dest_loc: Destination indices to copy to [copy_len] + stride_bs, stride_h, stride_d: Strides for the buffer + BLOCK_KV: Size of FP8 key data (128 bytes) + BLOCK_SCALE: Size of scale data (4 bytes) + """ + cur_index = tl.program_id(0) + offs_kv = tl.arange(0, BLOCK_KV) + offs_scale = tl.arange(0, BLOCK_SCALE) + + # Load source and destination indices + src_index = tl.load(src_loc + cur_index).to(tl.int64) + dest_index = tl.load(dest_loc + cur_index).to(tl.int64) + + # Copy FP8 key data (128 bytes) + src_kv_ptrs = buffer + src_index * stride_bs + stride_d * offs_kv + dest_kv_ptrs = buffer + dest_index * stride_bs + stride_d * offs_kv + kv_data = tl.load(src_kv_ptrs) + tl.store(dest_kv_ptrs, kv_data) + + # Copy scale data (4 bytes at offset 128) + src_scale_base = buffer + src_index * stride_bs + BLOCK_KV * stride_d + dest_scale_base = buffer + dest_index * stride_bs + BLOCK_KV * stride_d + scale_data = tl.load(src_scale_base + offs_scale * stride_d) + tl.store(dest_scale_base + offs_scale * stride_d, scale_data) + + return + + +@torch.no_grad() +def copy_indexer_ks( + buffer: torch.Tensor, + src_loc: torch.Tensor, + dest_loc: torch.Tensor, +): + """ + Copy indexer_ks data from source positions to destination positions. + + This function is used to copy cached tokens' indexer_ks data to new locations + after prefix cache matching. It ensures that the indexer_ks buffer stays + consistent with the KV cache buffer. 
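+
+    Functionally this is an indexed row copy over 132-byte entries; a plain
+    PyTorch sketch of the same operation (illustrative only, the Triton kernel
+    copies in-place without materializing the gathered source rows) would be:
+
+        buffer[dest_loc.long()] = buffer[src_loc.long()]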
+ + Args: + buffer: [large_size, 1, 132] torch.uint8 + Buffer containing indexer_ks data (same buffer for src and dest) + src_loc: [copy_len] torch.int32 or torch.int64 + Source indices in buffer (old positions) + dest_loc: [copy_len] torch.int32 or torch.int64 + Destination indices in buffer (new positions) + + Returns: + None (modifies buffer in-place) + + Example: + >>> buffer = torch.zeros((1024, 1, 132), dtype=torch.uint8).cuda() + >>> old_pos = torch.tensor([100, 101, 102], dtype=torch.int32).cuda() + >>> new_pos = torch.tensor([200, 201, 202], dtype=torch.int32).cuda() + >>> copy_indexer_ks(buffer, old_pos, new_pos) + # Data from positions [100, 101, 102] is now copied to [200, 201, 202] + """ + copy_len = src_loc.shape[0] + block_kv = 128 # FP8 key data size + block_scale = 4 # Float32 scale size + + assert ( + src_loc.shape[0] == dest_loc.shape[0] + ), f"src_loc and dest_loc must have same length: {src_loc.shape[0]} != {dest_loc.shape[0]}" + assert ( + buffer.shape[2] == block_kv + block_scale + ), f"Expected buffer last dim={block_kv + block_scale}, got {buffer.shape[2]}" + assert buffer.dtype == torch.uint8, f"Expected buffer dtype=uint8, got {buffer.dtype}" + + grid = (copy_len,) + num_warps = 1 + + _fwd_kernel_copy_indexer_ks[grid]( + buffer, + src_loc, + dest_loc, + buffer.stride(0), + buffer.stride(1), + buffer.stride(2), + BLOCK_KV=block_kv, + BLOCK_SCALE=block_scale, + num_warps=num_warps, + num_stages=1, + ) + + return + + +def test_copy_indexer_ks(): + """Test the copy_indexer_ks kernel""" + import torch.nn.functional as F + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks + from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks + + print("=" * 80) + print("Testing copy_indexer_ks") + print("=" * 80) + + # Test parameters + cached_len = 20 + buffer_size = 1024 + head_dim = 128 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create indexer_ks data + k_bf16 = torch.randn((cached_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + + # Write to old positions + old_positions = torch.arange(100, 100 + cached_len, dtype=torch.int32, device="cuda") + buffer = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8, k_scale, old_positions, buffer) + + # Copy to new positions + new_positions = torch.arange(200, 200 + cached_len, dtype=torch.int32, device="cuda") + copy_indexer_ks(buffer, old_positions, new_positions) + + # Verify data at new positions matches original + k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, new_positions) + + fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) + + scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) + + # Check dequantized values + k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16, dim=-1).mean() + + print(f"Cached tokens: {cached_len}, Head dim: {head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert 
fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" + assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test with sequential indices + print("Testing sequential indices...") + old_pos_seq = torch.arange(20, dtype=torch.int32, device="cuda") + new_pos_seq = torch.arange(200, 220, dtype=torch.int32, device="cuda") + + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + + buffer_seq = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, old_pos_seq, buffer_seq) + copy_indexer_ks(buffer_seq, old_pos_seq, new_pos_seq) + + k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, new_pos_seq) + + fp8_match_seq = torch.allclose(k_fp8_ext_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) + scale_match_seq = torch.allclose(k_scale_ext_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) + + print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Sequential test passed!") + print() + + # Test with single element + print("Testing single element...") + old_pos_single = torch.tensor([42], dtype=torch.int32, device="cuda") + new_pos_single = torch.tensor([424], dtype=torch.int32, device="cuda") + + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = ( + (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + ) + + buffer_single = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, old_pos_single, buffer_single) + copy_indexer_ks(buffer_single, old_pos_single, new_pos_single) + + k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, new_pos_single) + + fp8_match_single = torch.allclose( + k_fp8_ext_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 + ) + scale_match_single = torch.allclose(k_scale_ext_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) + + print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! ✓") + print("=" * 80) + + +if __name__ == "__main__": + test_copy_indexer_ks() diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index 46095bfb7..8faf3cdea 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -21,55 +21,57 @@ def _fwd_kernel_destindex_copy_indexer_ks( ): """ Triton kernel to copy FP8 K values and their scales to an indexed output buffer. 
-
+
    This kernel reads FP8 key values (128 dims) and their float32 scale values,
    then writes them to a compact buffer format where each entry contains:
    - Bytes 0-127: FP8 key values (128 bytes)
    - Bytes 128-131: Float32 scale (4 bytes)
-
+
    The destination location for each source element is specified by DestLoc.
    """
    cur_index = tl.program_id(0)
    offs_d = tl.arange(0, BLOCK_DMODEL)
-
+
    # Load destination index for this thread
    dest_index = tl.load(DestLoc + cur_index).to(tl.int64)
-
+
    # Load K_fp8 (128 values) and K_scale (1 value) from source
    k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d
    k_fp8 = tl.load(k_fp8_ptrs)
-
+
    k_scale = tl.load(K_scale + cur_index * stride_scale_bs)
-
+
    # Store K_fp8 to O_buffer[:, 0, :128]
    # Convert fp8 to uint8 through bitcast for storage in uint8 buffer
    o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d
    k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True)
    tl.store(o_k_ptrs, k_fp8_as_uint8)
-
+
    # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32)
    # Convert float32 scale to 4 uint8 bytes using bitcast and bit manipulation
    o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d
    scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True)
-
+
    # Store each byte of the float32 scale (little-endian)
    for i in range(4):
        byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8)
        tl.store(o_scale_ptr + i * stride_o_d, byte_val)
-
+
    return


 @torch.no_grad()
-def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor):
+def destindex_copy_indexer_ks(
+    K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor
+):
    """
    Copy FP8-quantized key values and their scales to indexed locations in a buffer.
-
+
    This function is used in the DeepSeek-V3.2 NSA (Native Sparse Attention)
    mechanism to store compressed key representations in a memory buffer.
    Each key is stored with its FP8 representation (128 bytes) followed by
    its float32 scale (4 bytes), for a total of 132 bytes per key.
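
    An entry written this way can be unpacked again as in the unit test at the
    bottom of this file (illustrative sketch; O_buffer is the uint8 buffer):

        k_fp8_view = O_buffer[:, 0, :128].view(torch.float8_e4m3fn)
        k_scale_view = O_buffer[:, 0, 128:132].contiguous().view(-1, 4).view(torch.float32).squeeze(-1)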
- + Args: K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn FP8-quantized key values @@ -84,7 +86,7 @@ def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLo Returns: None (modifies O_buffer in-place) - + Example: >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() >>> k_scale = torch.randn(50, 1).cuda() @@ -95,14 +97,21 @@ def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLo """ seq_len = DestLoc.shape[0] head_dim = K_fp8.shape[1] - + assert head_dim == 128, f"Expected head_dim=128, got {head_dim}" - assert K_scale.shape[0] == seq_len + + # Handle cases where tensor lengths don't match (e.g., during prefix cache) + actual_seq_len = min(K_scale.shape[0], seq_len) + if actual_seq_len < seq_len: + K_fp8 = K_fp8[:actual_seq_len] + K_scale = K_scale[:actual_seq_len] + DestLoc = DestLoc[:actual_seq_len] + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" - - grid = (seq_len,) + + grid = (actual_seq_len,) num_warps = 1 - + _fwd_kernel_destindex_copy_indexer_ks[grid]( K_fp8, K_scale, @@ -125,151 +134,125 @@ def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLo def test_destindex_copy_indexer_ks(): """Test the destindex_copy_indexer_ks kernel""" import torch.nn.functional as F - + print("=" * 80) print("Testing destindex_copy_indexer_ks") print("=" * 80) - + # Test parameters q_seq_len = 50 head_dim = 128 large_size = 1024 dtype = torch.bfloat16 fp8_type = torch.float8_e4m3fn - + # Create random destination indices dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() actual_seq_len = len(dest_loc) - + # Create input tensors k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") - + # Quantize to FP8 k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8 = (k_bf16 / k_abs_max).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - + k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + # Create output buffer (as uint8 to allow reinterpretation) o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - + # Run kernel destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) - + # Extract results k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) - + # Extract scale by reinterpreting 4 bytes as float32 scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() k_scale_out = scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) - + # Verify results at destination locations k_fp8_extracted = k_fp8_out[dest_loc] k_scale_extracted = k_scale_out[dest_loc] - + # Check FP8 values match - fp8_match = torch.allclose( - k_fp8_extracted.to(torch.float32), - k_fp8.to(torch.float32), - atol=0, rtol=0 - ) - + fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) + # Check scales match - scale_match = torch.allclose( - k_scale_extracted, - k_scale.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - + scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) + # Check dequantized values k_dequant_out = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() - + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") print(f" FP8 values match: {fp8_match}") print(f" 
Scale values match: {scale_match}") print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - + assert fp8_match, "FP8 values do not match!" assert scale_match, "Scale values do not match!" assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - + print("✓ Basic test passed!") print() - + # Test edge cases print("Testing edge cases...") - + # Test with sequential indices dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) - + k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) - - fp8_match_seq = torch.allclose( - k_fp8_out_seq.to(torch.float32), - k_fp8_seq.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_seq = torch.allclose( - k_scale_out_seq, - k_scale_seq.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - + + fp8_match_seq = torch.allclose(k_fp8_out_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) + scale_match_seq = torch.allclose(k_scale_out_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) + print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") assert fp8_match_seq and scale_match_seq print("✓ Edge case tests passed!") print() - + # Test with single element print("Testing single element...") dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - + k_fp8_single = ( + (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + ) + o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) - + k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) - + fp8_match_single = torch.allclose( - k_fp8_out_single.to(torch.float32), - k_fp8_single.to(torch.float32), - atol=0, rtol=0 + k_fp8_out_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 ) - scale_match_single = torch.allclose( - k_scale_out_single, - k_scale_single.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - + scale_match_single = torch.allclose(k_scale_out_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) + print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") assert fp8_match_single and scale_match_single print("✓ Single element test 
passed!") print() - + print("=" * 80) print("All tests passed successfully! ✓") print("=" * 80) if __name__ == "__main__": - test_destindex_copy_indexer_ks() \ No newline at end of file + test_destindex_copy_indexer_ks() From 61a3e0b2674120fafd25853a5831c24097eef76c Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 2 Feb 2026 11:45:47 +0000 Subject: [PATCH 14/23] abstract NSA attention into backend framework Add NSA (Native Sparse Attention) backend abstraction following the existing MLA pattern. This enables future support for multiple NSA implementations (flashmla_sparse, fa3, tilelang, aiter). - Add attention framework from origin/main with NSA extensions - Create NsaFlashMlaSparseAttBackend with prefill/decode states - Extend AttControl with nsa_prefill/nsa_decode params - Add factory functions get_nsa_*_att_backend_class() - Refactor DeepSeek V3.2 to use NSA backend - Add missing envs_utils functions for compatibility --- .../common/basemodel/attention/__init__.py | 5 + .../common/basemodel/attention/base_att.py | 5 + .../basemodel/attention/create_utils.py | 47 +++++ .../basemodel/attention/nsa/__init__.py | 13 ++ .../attention/nsa/flashmla_sparse.py | 172 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 81 ++++++--- 6 files changed, 299 insertions(+), 24 deletions(-) create mode 100644 lightllm/common/basemodel/attention/nsa/__init__.py create mode 100644 lightllm/common/basemodel/attention/nsa/flashmla_sparse.py diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 80df54549..0eea52cc8 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -10,9 +10,14 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +# NSA backend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend + from .create_utils import ( get_prefill_att_backend_class, get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, + get_nsa_prefill_att_backend_class, + get_nsa_decode_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca8..1286a46ec 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -65,6 +65,11 @@ class AttControl: mla_prefill_dict: Dict = None mla_decode: bool = False mla_decode_dict: Dict = None + # nsa (native sparse attention) 专用传参项 + nsa_prefill: bool = False + nsa_prefill_dict: Dict = None + nsa_decode: bool = False + nsa_decode_dict: Dict = None @dataclass diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 19252cf13..dd3802895 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -17,6 +17,9 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +# NSA backend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend + logger = init_logger(__name__) # Backend class mappings by data type @@ -46,6 +49,14 @@ }, } +# NSA (Native Sparse Attention) backend mappings +nsa_data_type_to_backend = { + "None": { + "flashmla_sparse": NsaFlashMlaSparseAttBackend, + # Future backends: "fa3", "tilelang", "aiter" + }, +} + def _auto_select_backend( llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", 
"triton"] @@ -105,3 +116,39 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla return mla_data_type_to_backend[llm_dtype][backend_str] else: return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + + +def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: + """Get NSA prefill attention backend class. + + Args: + backend_str: Backend name, currently only "flashmla_sparse" is supported. + Future options: "fa3", "tilelang", "aiter" + + Returns: + NSA attention backend class + """ + # NSA currently only supports "None" dtype (no quantization) + llm_dtype = "None" + if backend_str not in nsa_data_type_to_backend[llm_dtype]: + logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") + backend_str = "flashmla_sparse" + return nsa_data_type_to_backend[llm_dtype][backend_str] + + +def get_nsa_decode_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: + """Get NSA decode attention backend class. + + Args: + backend_str: Backend name, currently only "flashmla_sparse" is supported. + Future options: "fa3", "tilelang", "aiter" + + Returns: + NSA attention backend class + """ + # NSA currently only supports "None" dtype (no quantization) + llm_dtype = "None" + if backend_str not in nsa_data_type_to_backend[llm_dtype]: + logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") + backend_str = "flashmla_sparse" + return nsa_data_type_to_backend[llm_dtype][backend_str] diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py new file mode 100644 index 000000000..11a1ebfdc --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/__init__.py @@ -0,0 +1,13 @@ +"""NSA (Native Sparse Attention) backend implementations.""" + +from .flashmla_sparse import ( + NsaFlashMlaSparseAttBackend, + NsaFlashMlaSparsePrefillAttState, + NsaFlashMlaSparseDecodeAttState, +) + +__all__ = [ + "NsaFlashMlaSparseAttBackend", + "NsaFlashMlaSparsePrefillAttState", + "NsaFlashMlaSparseDecodeAttState", +] diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py new file mode 100644 index 000000000..8e5249999 --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -0,0 +1,172 @@ +"""NSA FlashMLA-sparse attention backend implementation. + +This backend uses sgl_kernel's flash_mla_sparse_fwd for prefill +and flash_attn_with_kvcache for decode with sparse indices. 
+""" + +import dataclasses +import torch +from typing import Tuple, TYPE_CHECKING + +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class NsaFlashMlaSparseAttBackend(BaseAttBackend): + """NSA backend using FlashMLA sparse kernels from sgl_kernel.""" + + def __init__(self, model): + super().__init__(model=model) + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState": + return NsaFlashMlaSparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparseDecodeAttState": + return NsaFlashMlaSparseDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaSparsePrefillAttState(BasePrefillAttState): + """Prefill attention state for NSA using flash_mla_sparse_fwd.""" + + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + """Execute NSA prefill attention. + + Args: + q: Query tensor [total_tokens, num_heads, head_dim] - already projected with k_b_proj + k: KV buffer tensor from memory manager + v: Not used for NSA (pass None) + att_control: Must have nsa_prefill=True and nsa_prefill_dict with: + - topk_indices: Sparse attention indices [total_tokens, topk] + - softmax_scale: Attention softmax scale + - kv_lora_rank: d_v dimension for MLA + + Returns: + Output tensor [total_tokens, num_heads, kv_lora_rank] + """ + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + + return self._nsa_prefill_att(q=q, kv=k, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_mla import flash_mla_sparse_fwd + + nsa_dict = att_control.nsa_prefill_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + # flash_mla_sparse_fwd expects indices with shape [total_tokens, 1, topk] + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + mla_out, _, _ = flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=topk_indices, + sm_scale=softmax_scale, + d_v=kv_lora_rank, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): + """Decode attention state for NSA using flash_attn_with_kvcache.""" + + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + """Execute NSA decode attention. 
+ + Args: + q: Tuple of (q_nope, q_rope) tensors + k: KV buffer tensor from memory manager + v: Not used for NSA (pass None) + att_control: Must have nsa_decode=True and nsa_decode_dict with: + - topk_indices: Page table for sparse attention [batch, topk] + - nsa_cache_seqlens: Cache sequence lengths for NSA + - nsa_cu_seqlens_k: Cumulative sequence lengths for NSA + - softmax_scale: Attention softmax scale + - kv_lora_rank: d_v dimension for MLA + - qk_rope_head_dim: Rope head dimension + + Returns: + Output tensor + """ + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + + return self._nsa_decode_att(q=q, kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_attn import flash_attn_with_kvcache + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + nsa_cache_seqlens = nsa_dict["nsa_cache_seqlens"] + nsa_cu_seqlens_k = nsa_dict["nsa_cu_seqlens_k"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + qk_rope_head_dim = nsa_dict["qk_rope_head_dim"] + + q_nope, q_rope = q + + # Extract k_rope and kv_nope from the KV buffer + k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, 1, 1, kv_lora_rank) + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=topk_indices, + cache_seqlens=nsa_cache_seqlens, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=nsa_cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=softmax_scale, + causal=True, + ) + return o_tensor diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index cf748bcdb..1abb749a8 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -2,8 +2,6 @@ from typing import override import torch -from sgl_kernel.flash_mla import flash_mla_sparse_fwd -from sgl_kernel.flash_attn import flash_attn_with_kvcache from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer @@ -12,6 +10,8 @@ from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.attention.create_utils import get_nsa_prefill_att_backend_class class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): @@ -21,8 +21,19 @@ def __init__(self, layer_num, network_config, mode=[]): self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_, mode=mode) self.topk_indices = None + + # Initialize NSA attention backend (singleton, lazy initialization) + self._nsa_backend_class = get_nsa_prefill_att_backend_class() + self._nsa_backend = None return + def _get_nsa_backend(self): + """Get or create the NSA backend (lazy initialization).""" + if 
self._nsa_backend is None: + # NSA backend doesn't require model reference for basic operations + self._nsa_backend = self._nsa_backend_class(model=None) + return self._nsa_backend + @override def _get_qkv( self, @@ -80,16 +91,30 @@ def _nsa_context_attention_kernel( layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: - + # Model-specific q projection (uses layer weights) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - mla_out, _, _ = flash_mla_sparse_fwd( + + # Use NSA backend for attention computation + att_control = AttControl( + nsa_prefill=True, + nsa_prefill_dict={ + "topk_indices": self.topk_indices, + "softmax_scale": self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + }, + ) + + # Create prefill state and execute attention + nsa_backend = self._get_nsa_backend() + prefill_state = nsa_backend.create_att_prefill_state(infer_state) + prefill_state.init_state() + mla_out = prefill_state.prefill_att( q=q_all, - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=self.topk_indices.unsqueeze(1), - sm_scale=self.softmax_scale, - d_v=self.kv_lora_rank, + k=infer_state.mem_manager.kv_buffer[self.layer_num_], + v=None, + att_control=att_control, ) return mla_out @@ -100,23 +125,31 @@ def _nsa_token_attention_kernel( layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ): + # Model-specific q projection (uses layer weights) q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - - o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=self.topk_indices, - cache_seqlens=infer_state.nsa_cache_seqlens, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=self.softmax_scale, - causal=True, + + # Use NSA backend for attention computation + att_control = AttControl( + nsa_decode=True, + nsa_decode_dict={ + "topk_indices": self.topk_indices, + "nsa_cache_seqlens": infer_state.nsa_cache_seqlens, + "nsa_cu_seqlens_k": infer_state.nsa_cu_seqlens_k, + "softmax_scale": self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + "qk_rope_head_dim": self.qk_rope_head_dim, + }, + ) + + # Create decode state and execute attention + nsa_backend = self._get_nsa_backend() + decode_state = nsa_backend.create_att_decode_state(infer_state) + decode_state.init_state() + o_tensor = decode_state.decode_att( + q=(q_nope, q_rope), + k=infer_state.mem_manager.kv_buffer[self.layer_num_], + v=None, + att_control=att_control, ) return o_tensor From e5a4ea9fce86393d23a0a6b2dbc4060f8794b77a Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 08:04:39 +0000 Subject: [PATCH 15/23] add support --- .../basemodel/attention/create_utils.py | 23 - .../attention/nsa/flashmla_sparse.py | 41 -- lightllm/common/basemodel/basemodel.py | 20 - lightllm/common/infer_utils.py | 63 +-- .../deepseek2_mem_manager.py | 4 +- .../kv_cache_mem_manager/mem_manager.py | 30 +- lightllm/models/deepseek3_2/infer_struct.py | 196 +++---- 
.../layer_infer/nsa_indexer_layer_inder.py | 6 +- .../layer_infer/transformer_layer_infer.py | 18 +- .../layer_weights/nsa_indexer_layer_weight.py | 15 +- lightllm/models/deepseek3_2/mem_manager.py | 43 +- lightllm/models/deepseek3_2/model.py | 4 +- .../triton_kernel/copy_indexer_ks.py | 232 -------- lightllm/server/api_cli.py | 2 +- lightllm/server/api_openai.py | 33 ++ lightllm/server/core/objs/sampling_params.py | 53 +- lightllm/server/function_call_parser.py | 503 +++++++++++++++++- .../tool_chat_template_deepseekv32.jinjia | 301 +++++++---- test/test_api/test_gsmk.py | 265 +++++++++ 19 files changed, 1173 insertions(+), 679 deletions(-) delete mode 100644 lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py create mode 100644 test/test_api/test_gsmk.py diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index dd3802895..e3bf81dae 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -16,8 +16,6 @@ from .flashinfer.fp8 import Fp8FlashInferAttBackend from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend - -# NSA backend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend logger = init_logger(__name__) @@ -49,7 +47,6 @@ }, } -# NSA (Native Sparse Attention) backend mappings nsa_data_type_to_backend = { "None": { "flashmla_sparse": NsaFlashMlaSparseAttBackend, @@ -119,16 +116,6 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: - """Get NSA prefill attention backend class. - - Args: - backend_str: Backend name, currently only "flashmla_sparse" is supported. - Future options: "fa3", "tilelang", "aiter" - - Returns: - NSA attention backend class - """ - # NSA currently only supports "None" dtype (no quantization) llm_dtype = "None" if backend_str not in nsa_data_type_to_backend[llm_dtype]: logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") @@ -137,16 +124,6 @@ def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> B def get_nsa_decode_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: - """Get NSA decode attention backend class. - - Args: - backend_str: Backend name, currently only "flashmla_sparse" is supported. - Future options: "fa3", "tilelang", "aiter" - - Returns: - NSA attention backend class - """ - # NSA currently only supports "None" dtype (no quantization) llm_dtype = "None" if backend_str not in nsa_data_type_to_backend[llm_dtype]: logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 8e5249999..3eec98f05 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -1,9 +1,3 @@ -"""NSA FlashMLA-sparse attention backend implementation. - -This backend uses sgl_kernel's flash_mla_sparse_fwd for prefill -and flash_attn_with_kvcache for decode with sparse indices. 
-""" - import dataclasses import torch from typing import Tuple, TYPE_CHECKING @@ -16,8 +10,6 @@ class NsaFlashMlaSparseAttBackend(BaseAttBackend): - """NSA backend using FlashMLA sparse kernels from sgl_kernel.""" - def __init__(self, model): super().__init__(model=model) @@ -47,20 +39,6 @@ def prefill_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - """Execute NSA prefill attention. - - Args: - q: Query tensor [total_tokens, num_heads, head_dim] - already projected with k_b_proj - k: KV buffer tensor from memory manager - v: Not used for NSA (pass None) - att_control: Must have nsa_prefill=True and nsa_prefill_dict with: - - topk_indices: Sparse attention indices [total_tokens, topk] - - softmax_scale: Attention softmax scale - - kv_lora_rank: d_v dimension for MLA - - Returns: - Output tensor [total_tokens, num_heads, kv_lora_rank] - """ assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" @@ -79,7 +57,6 @@ def _nsa_prefill_att( softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] - # flash_mla_sparse_fwd expects indices with shape [total_tokens, 1, topk] if topk_indices.ndim == 2: topk_indices = topk_indices.unsqueeze(1) @@ -95,7 +72,6 @@ def _nsa_prefill_att( @dataclasses.dataclass class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): - """Decode attention state for NSA using flash_attn_with_kvcache.""" cu_seqlens_q: torch.Tensor = None cu_seqlens_k: torch.Tensor = None @@ -112,23 +88,6 @@ def decode_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - """Execute NSA decode attention. - - Args: - q: Tuple of (q_nope, q_rope) tensors - k: KV buffer tensor from memory manager - v: Not used for NSA (pass None) - att_control: Must have nsa_decode=True and nsa_decode_dict with: - - topk_indices: Page table for sparse attention [batch, topk] - - nsa_cache_seqlens: Cache sequence lengths for NSA - - nsa_cu_seqlens_k: Cumulative sequence lengths for NSA - - softmax_scale: Attention softmax scale - - kv_lora_rank: d_v dimension for MLA - - qk_rope_head_dim: Rope head dimension - - Returns: - Output tensor - """ assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c11b68c99..5c1d2b871 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -476,24 +476,6 @@ def _prefill( ) infer_state = self._create_inferstate(model_input) - - # Capture old indexer_ks positions before they are overwritten - # This is needed for DeepSeek v3.2 to copy cached tokens' indexer_ks - old_indexer_ks_positions = [] - for i in range(infer_state.b_req_idx.shape[0]): - req_idx = infer_state.b_req_idx[i].item() - ready_cache_len = infer_state.b_ready_cache_len[i].item() - - if ready_cache_len > 0: - # Capture old positions for cached tokens - old_pos = self.req_manager.req_to_token_indexs[ - req_idx, 0:ready_cache_len - ].clone() # Clone to avoid view issues - old_indexer_ks_positions.append(old_pos) - else: - # No cached tokens for this request - old_indexer_ks_positions.append(None) - init_req_to_token_indexes( req_to_token_indexs=self.req_manager.req_to_token_indexs, b_req_idx=infer_state.b_req_idx, @@ -502,8 +484,6 @@ def _prefill( 
b_start_loc=model_input.b_prefill_start_loc, alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, - mem_manager=self.req_manager.mem_manager, - old_indexer_ks_positions=old_indexer_ks_positions, ) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index ed3c0b73e..26cf973be 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,3 +1,4 @@ +import torch from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req_prefill @@ -9,10 +10,7 @@ def init_req_to_token_indexes( b_start_loc, alloc_mem_index, max_q_seq_len, - mem_manager=None, - old_indexer_ks_positions=None, ): - # Step 1: Copy KV cache for NEW tokens (existing logic) copy_kv_index_to_req_prefill( req_to_token_indexs=req_to_token_indexs, b_req_idx=b_req_idx, @@ -22,62 +20,3 @@ def init_req_to_token_indexes( memindex=alloc_mem_index, max_q_seq_len=max_q_seq_len, ) - - # Step 2: Copy indexer_ks for CACHED tokens (DeepSeek v3.2 specific) - # This ensures consistency between KV cache and indexer_ks buffers - # when prefix cache is hit - if ( - mem_manager is not None - and hasattr(mem_manager, "indexer_ks_mem_manager") - and old_indexer_ks_positions is not None - ): - - _copy_cached_indexer_ks_to_new_positions( - req_to_token_indexs=req_to_token_indexs, - b_req_idx=b_req_idx, - b_ready_cache_len=b_ready_cache_len, - mem_manager=mem_manager, - old_indexer_ks_positions=old_indexer_ks_positions, - ) - - -def _copy_cached_indexer_ks_to_new_positions( - req_to_token_indexs, - b_req_idx, - b_ready_cache_len, - mem_manager, - old_indexer_ks_positions, -): - """ - Copy cached tokens' indexer_ks from old positions to new positions. - - This function is called after copy_kv_index_to_req_prefill() has updated - req_to_token_indexs to point to new contiguous positions. We need to copy - indexer_ks data to match the KV cache layout. 
- - For each layer and each request with cached tokens: - - Copy indexer_ks data from old positions to new positions - - This ensures consistency when using extract_indexer_ks later - """ - from lightllm.models.deepseek3_2.triton_kernel.copy_indexer_ks import copy_indexer_ks - - # Get number of layers from indexer_ks_mem_manager - num_layers = len(mem_manager.indexer_ks_mem_manager.kv_buffer) - indexer_buffer = mem_manager.indexer_ks_mem_manager.kv_buffer - - for layer_idx in range(num_layers): - for i in range(b_req_idx.shape[0]): - req_idx = b_req_idx[i].item() - ready_cache_len = b_ready_cache_len[i].item() - old_positions = old_indexer_ks_positions[i] - - if ready_cache_len > 0 and old_positions is not None: - # New positions after copy_kv_index_to_req_prefill - new_positions = req_to_token_indexs[req_idx, 0:ready_cache_len] - - # Copy indexer_ks: old_positions -> new_positions - copy_indexer_ks( - buffer=indexer_buffer[layer_idx], - src_loc=old_positions, - dest_loc=new_positions, - ) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index ad54b3935..3d93e1b07 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -15,8 +15,8 @@ class Deepseek2MemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 7d5e2af04..1203cbdec 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -26,9 +26,7 @@ class MemoryManager: - def __init__( - self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False - ): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -50,16 +48,15 @@ def __init__( self.can_use_mem_size = self.size - if not is_sub_mem_manager: - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, @@ -95,17 +92,6 @@ def profile_size(self, mem_fraction): available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) cell_size = 
self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) - - # Ensure size is at least a minimum positive value to avoid torch.arange errors - MIN_SIZE = 1024 # Minimum 1024 tokens - if self.size < MIN_SIZE: - logger.warning( - f"Insufficient memory for KV cache. Available: {available_memory:.2f} GB, " - f"but calculated size is {self.size} tokens. Using minimum size {MIN_SIZE} tokens instead. " - f"Consider reducing model size, using fewer GPUs, or increasing mem_fraction." - ) - self.size = MIN_SIZE - if world_size > 1: tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") dist.all_reduce(tensor, op=dist.ReduceOp.MIN) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 2f8aa7562..e0cca499b 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,10 +1,10 @@ import torch import weakref -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager -class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): +class Deepseek3_2FlashAttentionStateInfo(Deepseek2InferStateInfo): _shared_nsa_buffers = None def __init__(self): @@ -21,7 +21,6 @@ def __init__(self): def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): """Get or create pre-allocated buffers for CUDA graph execution""" if cls._shared_nsa_buffers is None: - # Pre-allocate buffers for max possible sizes max_total_q_tokens = graph_max_batch_size * max_seq_len max_total_tokens = graph_max_batch_size * max_seq_len @@ -35,7 +34,7 @@ def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), }, - { # Second buffer for microbatch overlap if needed + { "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), @@ -47,155 +46,124 @@ def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): ] return cls._shared_nsa_buffers - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def _check_use_cuda_graph_buffers(self): + if hasattr(self, "_model_ref"): + model = self._model_ref() + if ( + model is not None + and hasattr(model, "graph_max_batch_size") + and hasattr(model, "graph_max_len_in_batch") + and self.batch_size <= model.graph_max_batch_size + and self.max_len_in_batch <= model.graph_max_len_in_batch + ): + return True + return False + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) - # Store weak reference to model for accessing graph parameters self._model_ref = weakref.ref(model) assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) - self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager + self.indexer_ks_buffer = self.mem_manager.indexer_ks_buffer - # Ensure b_ready_cache_len is set for both prefill and decode modes if self.is_prefill: - # b_ready_cache_len is already set in basemodel.py for prefill pass else: - # In decode mode, b_ready_cache_len is set by the 
router/scheduler - # based on actual prefix cache hits. If it's None (no prefix cache enabled), - # it should be 0, not computed from b_seq_len - b_q_seq_len if self.b_ready_cache_len is None: self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) - # Check if we can use CUDA graph based on batch size and max_len constraints - use_cuda_graph_buffers = False - if ( - hasattr(model, "graph_max_batch_size") - and hasattr(model, "graph_max_len_in_batch") - and self.batch_size <= model.graph_max_batch_size - and self.max_len_in_batch <= model.graph_max_len_in_batch - ): - use_cuda_graph_buffers = True + use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() - # Setup nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph if use_cuda_graph_buffers: buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - - # Use views into pre-allocated buffers self.nsa_cache_seqlens = buffer["nsa_cache_seqlens"][: self.batch_size] self.nsa_cu_seqlens_k = buffer["nsa_cu_seqlens_k"][: self.batch_size + 1] else: - # Create new tensors dynamically self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") - # Calculate actual values self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 - # Compute cumulative sum with padding torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) self.nsa_cu_seqlens_k[0] = 0 - # Pre-compute NSA indexer indexing structures self._init_nsa_indexing_structures() def _init_nsa_indexing_structures(self): - """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" - req_all_mem_index_list = [] - ks_list = [] - ke_list = [] - lengths_list = [] - offset = 0 - num_seq_len = self.b_req_idx.shape[0] - max_seq_len = self.b_seq_len.max().item() - - # Calculate total sizes needed - total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) - total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) - - # Check if we should use CUDA graph buffers - use_cuda_graph_buffers = False - if hasattr(self, "_model_ref"): - model = self._model_ref() - if ( - model is not None - and hasattr(model, "graph_max_batch_size") - and hasattr(model, "graph_max_len_in_batch") - and self.batch_size <= model.graph_max_batch_size - and self.max_len_in_batch <= model.graph_max_len_in_batch - ): - use_cuda_graph_buffers = True + """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer. + + Fully vectorized: eliminates per-request .item() CPU-GPU syncs. 
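+
+        A small worked example (toy values): for b_q_seq_len = [2, 1] and
+        b_seq_len = [3, 5], the per-token results are
+            ks      = [0, 0, 3]
+            ke      = [1, 2, 4]
+            lengths = [2, 3, 5]
+        where ks/ke are built from the cumulative seq_len offsets and lengths
+        is each query token's causal context length.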
+ """ + b_seq_len = self.b_seq_len + b_q_seq_len = self.b_q_seq_len + b_req_idx = self.b_req_idx + num_seq = b_req_idx.shape[0] + device = b_seq_len.device + + # Only 3 scalar syncs needed (for tensor shapes) + max_seq_len = b_seq_len.max().item() + total_q_len = b_q_seq_len.sum().item() + total_seq_len = b_seq_len.sum().item() + + # --- page_table_size_1 and req_all_mem_index (vectorized gather) --- + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + + # page_table_size_1: [batch, max_seq_len] zero-padded memory indices + page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=device) + page_table[valid_mask] = all_rows[valid_mask].int() + + # req_all_mem_index: flattened valid memory indices across all requests + req_all_mem_index = all_rows[valid_mask] + + # --- ks, ke, lengths (vectorized computation) --- + # Cumulative seq_len offsets: [0, seq_len[0], seq_len[0]+seq_len[1], ...] + cum_seq = torch.cumsum(b_seq_len, dim=0) + seq_offsets = torch.zeros_like(cum_seq) + seq_offsets[1:] = cum_seq[:-1] + + # Expand per-request values to per-token using repeat_interleave + req_indices = torch.repeat_interleave(torch.arange(num_seq, device=device), b_q_seq_len) + + # Token position within each request's q_seq + cum_q = torch.cumsum(b_q_seq_len, dim=0) + q_offsets = torch.zeros_like(cum_q) + q_offsets[1:] = cum_q[:-1] + token_in_req = torch.arange(total_q_len, device=device) - q_offsets[req_indices] + + # ks[t] = seq_offset of request owning token t + # ke[t] = seq_offset + position_in_q + 1 + # lengths[t] = seq_len - q_seq_len + position_in_q + 1 + ks = seq_offsets[req_indices].int() + ke = (seq_offsets[req_indices] + token_in_req + 1).int() + lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req + 1).int() + + # --- Assign results (CUDA graph buffer or new tensors) --- + use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() if use_cuda_graph_buffers: - # Use pre-allocated buffers for CUDA graph model = self._model_ref() buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - # Use views into pre-allocated buffers self.ks = buffer["ks"][:total_q_len] self.ke = buffer["ke"][:total_q_len] self.lengths = buffer["lengths"][:total_q_len] - self.page_table_size_1 = buffer["page_table_size_1"][:num_seq_len, :max_seq_len] + self.page_table_size_1 = buffer["page_table_size_1"][:num_seq, :max_seq_len] self.req_all_mem_index = buffer["req_all_mem_index"][:total_seq_len] - # Zero out page_table_size_1 before filling - self.page_table_size_1.zero_() - - # Compute and copy values into the pre-allocated buffer views - ks_offset = 0 - ke_offset = 0 - lengths_offset = 0 - req_offset = 0 - seq_offset = 0 - - for i in range(num_seq_len): - seq_len = self.b_seq_len[i].item() - q_seq_len = self.b_q_seq_len[i].item() - req_idx = self.b_req_idx[i].item() - mem_index = self.req_manager.req_to_token_indexs[req_idx, :seq_len] - - # Copy req_all_mem_index - self.req_all_mem_index[req_offset : req_offset + seq_len] = mem_index - - # Fill page_table_size_1 - self.page_table_size_1[i, :seq_len] = mem_index - - # Fill ks, ke, lengths - self.ks[ks_offset : ks_offset + q_seq_len].fill_(seq_offset) - self.ke[ke_offset : ke_offset + q_seq_len] = torch.arange( - seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device="cuda" - ) - self.lengths[lengths_offset 
: lengths_offset + q_seq_len] = torch.arange( - seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device="cuda" - ) - - ks_offset += q_seq_len - ke_offset += q_seq_len - lengths_offset += q_seq_len - req_offset += seq_len - seq_offset += seq_len + self.ks.copy_(ks) + self.ke.copy_(ke) + self.lengths.copy_(lengths) + self.page_table_size_1.copy_(page_table) + self.req_all_mem_index.copy_(req_all_mem_index) else: - # Original dynamic allocation for non-CUDA graph mode - self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device="cuda") - - for i in range(num_seq_len): - seq_len = self.b_seq_len[i].item() - q_seq_len = self.b_q_seq_len[i].item() - req_idx = self.b_req_idx[i].item() - mem_index = self.req_manager.req_to_token_indexs[req_idx, :seq_len] - req_all_mem_index_list.append(mem_index) - self.page_table_size_1[i, :seq_len] = mem_index - ks = torch.zeros(q_seq_len, dtype=torch.int, device="cuda") + offset - ke = torch.arange(q_seq_len, dtype=torch.int, device="cuda") + offset + 1 - ks_list.append(ks) - ke_list.append(ke) - lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device="cuda")) - offset += seq_len - - self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) - self.ks = torch.cat(ks_list, dim=0) - self.ke = torch.cat(ke_list, dim=0) - self.lengths = torch.cat(lengths_list, dim=0) + self.ks = ks + self.ke = ke + self.lengths = lengths + self.page_table_size_1 = page_table + self.req_all_mem_index = req_all_mem_index diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 2f4421e74..3855bf590 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -8,10 +8,8 @@ from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -84,7 +82,7 @@ def get_indices( k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8, k_scale, infer_state.mem_index, infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + k_fp8, k_scale, infer_state.mem_index, infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_] ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale @@ -97,7 +95,7 @@ def get_indices( # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( - infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index + infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index ) # Get actual sequence length from q (which comes from q_lora) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 1abb749a8..bc8bb9c6b 100644 --- 
a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,4 +1,3 @@ -from functools import partial from typing import override import torch @@ -8,7 +7,7 @@ from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.common.basemodel.attention.create_utils import get_nsa_prefill_att_backend_class @@ -73,17 +72,7 @@ def _get_qkv( return q, cache_kv @override - def _bind_attention(self): - if "triton_fp8kv" in self.mode: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) - else: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - - self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) - self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) - pass - - def _nsa_context_attention_kernel( + def _context_attention_kernel( self, q: torch.Tensor, kv, @@ -118,7 +107,8 @@ def _nsa_context_attention_kernel( ) return mla_out - def _nsa_token_attention_kernel( + @override + def _token_attention_kernel( self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 47e0bfdac..9ccfbe97e 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -3,7 +3,7 @@ import torch from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, LayerNormWeight class NSAIndexerWeight(TransformerLayerWeight): @@ -14,7 +14,7 @@ def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): @override def _init_weight(self): prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" - + self.wq_b_proj_ = ROWMMWeight( weight_name=f"{prefix}.wq_b.weight", data_type=self.data_type_, @@ -33,15 +33,16 @@ def _init_weight(self): tp_rank=0, tp_world_size=1, ) - self.k_norm_ = NormWeight( - f"{prefix}.k_norm.weight", - torch.float32, - bias_name=f"{prefix}.k_norm.bias" + self.k_norm_ = LayerNormWeight( + dim=self.network_config_["index_head_dim"], + weight_name=f"{prefix}.k_norm.weight", + data_type=torch.float32, + bias_name=f"{prefix}.k_norm.bias", ) self.weights_proj_ = ROWMMWeight( weight_name=f"{prefix}.weights_proj.weight", data_type=self.data_type_, - quant_cfg=None, + quant_cfg=None, layer_num=self.layer_num_, name="weights_proj", tp_rank=0, diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index a70c76273..8017a84ad 100644 --- 
a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,22 +1,41 @@ -from typing import List from typing_extensions import override import torch -from lightllm.common.mem_manager import MemoryManager -from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.distributed.pynccl import PyNcclCommunicator +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + + +class IndexerKSBuffer: + """Lightweight buffer holder for NSA indexer keys+scales. + + Shares token indices with the parent MemoryManager — does NOT have its + own allocator. Only stores the per-layer kv_buffer tensor. + """ + + def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int, dtype=torch.uint8): + self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + class Deepseek3_2MemoryManager(Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9 ,is_sub_mem_manager=False): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) - self.indexer_ks_mem_manager = Deepseek2MemoryManager(self.size, torch.uint8, 1, 132, layer_num, is_sub_mem_manager=True) - return + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, layer_num) @override def get_cell_size(self): return super().get_cell_size() + 132 - + + @override + def _free_buffers(self): + super()._free_buffers() + self.indexer_ks_buffer = None + + @override + def resize_mem(self, new_size): + super().resize_mem(new_size) + self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, self.layer_num) + + class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): - super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) \ No newline at end of file + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 8f1ba85cf..d25cbd378 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,6 +5,8 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager + + @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class @@ -44,4 +46,4 @@ def _init_mem_manager(self): layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, mem_fraction=self.mem_fraction, ) - return \ No newline at end of file + return diff --git a/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py deleted file mode 100644 index 
93cf463eb..000000000 --- a/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py +++ /dev/null @@ -1,232 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_copy_indexer_ks( - buffer, # [large_size, 1, 132] uint8 - src_loc, # [copy_len] int32/int64 - source indices - dest_loc, # [copy_len] int32/int64 - destination indices - stride_bs, - stride_h, - stride_d, - BLOCK_KV: tl.constexpr, # = 128 (FP8 data) - BLOCK_SCALE: tl.constexpr, # = 4 (scale data) -): - """ - Triton kernel to copy indexer_ks data from source locations to destination locations. - - This kernel copies 132-byte indexer_ks entries (128 bytes FP8 key + 4 bytes float32 scale) - from source positions to destination positions within the same buffer. - - Args: - buffer: Shared buffer containing indexer_ks data [large_size, 1, 132] uint8 - src_loc: Source indices to copy from [copy_len] - dest_loc: Destination indices to copy to [copy_len] - stride_bs, stride_h, stride_d: Strides for the buffer - BLOCK_KV: Size of FP8 key data (128 bytes) - BLOCK_SCALE: Size of scale data (4 bytes) - """ - cur_index = tl.program_id(0) - offs_kv = tl.arange(0, BLOCK_KV) - offs_scale = tl.arange(0, BLOCK_SCALE) - - # Load source and destination indices - src_index = tl.load(src_loc + cur_index).to(tl.int64) - dest_index = tl.load(dest_loc + cur_index).to(tl.int64) - - # Copy FP8 key data (128 bytes) - src_kv_ptrs = buffer + src_index * stride_bs + stride_d * offs_kv - dest_kv_ptrs = buffer + dest_index * stride_bs + stride_d * offs_kv - kv_data = tl.load(src_kv_ptrs) - tl.store(dest_kv_ptrs, kv_data) - - # Copy scale data (4 bytes at offset 128) - src_scale_base = buffer + src_index * stride_bs + BLOCK_KV * stride_d - dest_scale_base = buffer + dest_index * stride_bs + BLOCK_KV * stride_d - scale_data = tl.load(src_scale_base + offs_scale * stride_d) - tl.store(dest_scale_base + offs_scale * stride_d, scale_data) - - return - - -@torch.no_grad() -def copy_indexer_ks( - buffer: torch.Tensor, - src_loc: torch.Tensor, - dest_loc: torch.Tensor, -): - """ - Copy indexer_ks data from source positions to destination positions. - - This function is used to copy cached tokens' indexer_ks data to new locations - after prefix cache matching. It ensures that the indexer_ks buffer stays - consistent with the KV cache buffer. 
- - Args: - buffer: [large_size, 1, 132] torch.uint8 - Buffer containing indexer_ks data (same buffer for src and dest) - src_loc: [copy_len] torch.int32 or torch.int64 - Source indices in buffer (old positions) - dest_loc: [copy_len] torch.int32 or torch.int64 - Destination indices in buffer (new positions) - - Returns: - None (modifies buffer in-place) - - Example: - >>> buffer = torch.zeros((1024, 1, 132), dtype=torch.uint8).cuda() - >>> old_pos = torch.tensor([100, 101, 102], dtype=torch.int32).cuda() - >>> new_pos = torch.tensor([200, 201, 202], dtype=torch.int32).cuda() - >>> copy_indexer_ks(buffer, old_pos, new_pos) - # Data from positions [100, 101, 102] is now copied to [200, 201, 202] - """ - copy_len = src_loc.shape[0] - block_kv = 128 # FP8 key data size - block_scale = 4 # Float32 scale size - - assert ( - src_loc.shape[0] == dest_loc.shape[0] - ), f"src_loc and dest_loc must have same length: {src_loc.shape[0]} != {dest_loc.shape[0]}" - assert ( - buffer.shape[2] == block_kv + block_scale - ), f"Expected buffer last dim={block_kv + block_scale}, got {buffer.shape[2]}" - assert buffer.dtype == torch.uint8, f"Expected buffer dtype=uint8, got {buffer.dtype}" - - grid = (copy_len,) - num_warps = 1 - - _fwd_kernel_copy_indexer_ks[grid]( - buffer, - src_loc, - dest_loc, - buffer.stride(0), - buffer.stride(1), - buffer.stride(2), - BLOCK_KV=block_kv, - BLOCK_SCALE=block_scale, - num_warps=num_warps, - num_stages=1, - ) - - return - - -def test_copy_indexer_ks(): - """Test the copy_indexer_ks kernel""" - import torch.nn.functional as F - from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks - from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks - - print("=" * 80) - print("Testing copy_indexer_ks") - print("=" * 80) - - # Test parameters - cached_len = 20 - buffer_size = 1024 - head_dim = 128 - dtype = torch.bfloat16 - fp8_type = torch.float8_e4m3fn - - # Create indexer_ks data - k_bf16 = torch.randn((cached_len, head_dim), dtype=dtype, device="cuda") - - # Quantize to FP8 - k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - # Write to old positions - old_positions = torch.arange(100, 100 + cached_len, dtype=torch.int32, device="cuda") - buffer = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8, k_scale, old_positions, buffer) - - # Copy to new positions - new_positions = torch.arange(200, 200 + cached_len, dtype=torch.int32, device="cuda") - copy_indexer_ks(buffer, old_positions, new_positions) - - # Verify data at new positions matches original - k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, new_positions) - - fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) - - scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) - - # Check dequantized values - k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) - cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16, dim=-1).mean() - - print(f"Cached tokens: {cached_len}, Head dim: {head_dim}") - print(f" FP8 values match: {fp8_match}") - print(f" Scale values match: {scale_match}") - print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - - assert 
fp8_match, "FP8 values do not match!" - assert scale_match, "Scale values do not match!" - assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - - print("✓ Basic test passed!") - print() - - # Test with sequential indices - print("Testing sequential indices...") - old_pos_seq = torch.arange(20, dtype=torch.int32, device="cuda") - new_pos_seq = torch.arange(200, 220, dtype=torch.int32, device="cuda") - - k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") - k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - buffer_seq = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, old_pos_seq, buffer_seq) - copy_indexer_ks(buffer_seq, old_pos_seq, new_pos_seq) - - k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, new_pos_seq) - - fp8_match_seq = torch.allclose(k_fp8_ext_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) - scale_match_seq = torch.allclose(k_scale_ext_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") - assert fp8_match_seq and scale_match_seq - print("✓ Sequential test passed!") - print() - - # Test with single element - print("Testing single element...") - old_pos_single = torch.tensor([42], dtype=torch.int32, device="cuda") - new_pos_single = torch.tensor([424], dtype=torch.int32, device="cuda") - - k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") - k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = ( - (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - ) - - buffer_single = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_single, k_scale_single, old_pos_single, buffer_single) - copy_indexer_ks(buffer_single, old_pos_single, new_pos_single) - - k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, new_pos_single) - - fp8_match_single = torch.allclose( - k_fp8_ext_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 - ) - scale_match_single = torch.allclose(k_scale_ext_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") - assert fp8_match_single and scale_match_single - print("✓ Single element test passed!") - print() - - print("=" * 80) - print("All tests passed successfully! 
✓") - print("=" * 80) - - -if __name__ == "__main__": - test_copy_indexer_ks() diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 73b9bad4a..1661a3b87 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--tool_call_parser", type=str, - choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"], + choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "deepseekv32", "glm47", "kimi_k2"], default=None, help="tool call parser type", ) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index d91bb1d94..928d840c8 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -463,6 +463,39 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") # Additional usage chunk + # Finalize any pending tool calls (e.g., DSML format last invoke) + if request.tool_choice != "none" and request.tools and parser_dict: + for _idx, _parser in parser_dict.items(): + _, finalize_calls = _parser.finalize_stream() + history_tool_calls_cnt = _get_history_tool_calls_cnt(request) + for call_item in finalize_calls: + if call_item.name: + tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt) + function_name = call_item.name + else: + tool_call_id = None + function_name = None + tool_call = ToolCall( + id=tool_call_id, + index=getattr(call_item, "tool_index", None), + function=FunctionResponse( + name=function_name, + arguments=call_item.parameters, + ), + ) + choice_data = ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), + finish_reason="tool_calls", + ) + chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( prompt_tokens=prompt_tokens, diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index d955aa6a8..d7fd35961 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -331,31 +331,38 @@ class SamplingParams(ctypes.Structure): _top_p: float = 1.0 _top_k: int = -1 # -1 is for all + @staticmethod + def _get(kwargs, key, default): + """Get value from kwargs, falling back to default when value is None or missing.""" + val = kwargs.get(key) + return val if val is not None else default + def init(self, tokenizer, **kwargs): super().__init__() - self.best_of = kwargs.get("best_of", 1) - self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) - self.ignore_eos = kwargs.get("ignore_eos", False) - self.image_max_patch_num = kwargs.get("image_max_patch_num", 
-1) - self.max_new_tokens = kwargs.get("max_new_tokens", 16) - self.min_new_tokens = kwargs.get("min_new_tokens", 1) - self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) - self.group_request_id = kwargs.get("group_request_id", -1) - self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) - - self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) - self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) - - self.add_special_tokens = kwargs.get("add_special_tokens", True) - self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) - self.print_eos_token = kwargs.get("print_eos_token", False) + _get = SamplingParams._get + self.best_of = _get(kwargs, "best_of", 1) + self.n = _get(kwargs, "n", self.best_of) + self.do_sample = _get(kwargs, "do_sample", SamplingParams._do_sample) + self.presence_penalty = _get(kwargs, "presence_penalty", SamplingParams._presence_penalty) + self.frequency_penalty = _get(kwargs, "frequency_penalty", SamplingParams._frequency_penalty) + self.repetition_penalty = _get(kwargs, "repetition_penalty", SamplingParams._repetition_penalty) + self.temperature = _get(kwargs, "temperature", SamplingParams._temperature) + self.top_p = _get(kwargs, "top_p", SamplingParams._top_p) + self.top_k = _get(kwargs, "top_k", SamplingParams._top_k) + self.ignore_eos = _get(kwargs, "ignore_eos", False) + self.image_max_patch_num = _get(kwargs, "image_max_patch_num", -1) + self.max_new_tokens = _get(kwargs, "max_new_tokens", 16) + self.min_new_tokens = _get(kwargs, "min_new_tokens", 1) + self.input_penalty = _get(kwargs, "input_penalty", DEFAULT_INPUT_PENALTY) + self.group_request_id = _get(kwargs, "group_request_id", -1) + self.suggested_dp_index = _get(kwargs, "suggested_dp_index", -1) + + self.skip_special_tokens = _get(kwargs, "skip_special_tokens", SKIP_SPECIAL_TOKENS) + self.disable_prompt_cache = _get(kwargs, "disable_prompt_cache", False) + + self.add_special_tokens = _get(kwargs, "add_special_tokens", True) + self.add_spaces_between_special_tokens = _get(kwargs, "add_spaces_between_special_tokens", True) + self.print_eos_token = _get(kwargs, "print_eos_token", False) self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() self.exponential_decay_length_penalty.initialize(kwargs.get("exponential_decay_length_penalty", (1, 1.0))) diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 9214715b1..c3faf21e7 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -29,7 +29,15 @@ logger = logging.getLogger(__name__) -TOOLS_TAG_LIST = ["<|plugin|>", "", "<|python_tag|>", "[TOOL_CALLS]", "<|tool▁calls▁begin|>"] +TOOLS_TAG_LIST = [ + "<|plugin|>", + "", + "<|python_tag|>", + "[TOOL_CALLS]", + "<|tool▁calls▁begin|>", + "<|DSML|function_calls>", +] class ToolCallItem(BaseModel): @@ -1443,6 +1451,482 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text="", calls=calls) +class DeepSeekV32Detector(BaseFormatDetector): + """ + Detector for DeepSeek V3.2 model function call format (DSML). + + DeepSeek V3.2 uses a new DSML (DeepSeek Markup Language) format for tool calls, + which is XML-like rather than JSON-based. 
+ + Format Structure: + ``` + <|DSML|function_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">杭州 + <|DSML|parameter name="date" string="true">2024-01-16 + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">北京 + <|DSML|parameter name="date" string="true">2024-01-16 + ``` + + Key Components: + - Tool Calls Section: Starts with `<|DSML|function_calls>` + - Individual Invoke: `<|DSML|invoke name="function_name">` + - Parameters: `<|DSML|parameter name="param_name" string="true">value` + - Parameter types are inferred from the tool schema for proper JSON serialization + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 + """ + + def __init__(self): + super().__init__() + self.dsml_token = "|DSML|" + self.bot_token = "<|DSML|function_calls>" + self.eot_token = "" # DSML format has no explicit end token + self.invoke_prefix = '<|DSML|invoke name="' + self.parameter_prefix = '<|DSML|parameter name="' + + # Regex for complete parsing + self.invoke_regex = re.compile( + r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)(?=<|DSML|invoke|$)', + re.DOTALL, + ) + # Captures: (param_name, is_string, value) + self.parameter_regex = re.compile( + r'<|DSML|parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)(?=<|DSML|parameter|<|DSML|invoke|$)', + re.DOTALL, + ) + + # Streaming state + self._last_arguments = "" + self._current_invoke_text = "" + self._invoke_count = 0 + self._param_count_in_invoke = 0 + self._accumulated_params: Dict[str, str] = {} + self._json_started = False + self._tools_schema: Optional[Dict[str, Dict]] = None + self._tool_indices: Optional[Dict[str, int]] = None + self._current_func_name: Optional[str] = None + self._in_tool_call_sequence = False # Set True once bot_token seen + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a DeepSeek V3.2 DSML format tool call.""" + return self.bot_token in text + + def _get_param_type(self, func_name: str, param_name: str, tools: List[Tool]) -> str: + """Get the JSON Schema type of a parameter from the tool definition.""" + if self._tools_schema is None: + self._tools_schema = {} + for tool in tools: + if tool.function.name and tool.function.parameters: + props = tool.function.parameters.get("properties", {}) + self._tools_schema[tool.function.name] = props + + func_schema = self._tools_schema.get(func_name, {}) + param_schema = func_schema.get(param_name, {}) + return param_schema.get("type", "string") + + def _convert_param_value(self, value: str, is_string_attr: str, param_type: str) -> Any: + """Convert a raw parameter value string to the appropriate Python type. + + Args: + value: The raw string value from the DSML parameter tag. + is_string_attr: The "string" attribute from DSML ("true" or "false"). + If "true", the value is treated as a raw string. + If "false", the value is parsed based on param_type or JSON. + param_type: The JSON Schema type from the tool definition (fallback). + """ + value = value.strip() + if value.lower() == "null": + return None + + # Use DSML string attribute as primary signal + if is_string_attr == "true": + return value + + # string="false" - parse based on schema type or attempt JSON + param_type = param_type.lower() + if param_type in ("integer", "int"): + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ("number", "float"): + try: + val = float(value) + # Only coerce to int if it's actually an integer string + if "." 
not in value and "e" not in value.lower(): + return int(value) + return val + except (ValueError, TypeError, OverflowError): + return value + elif param_type in ("boolean", "bool"): + lower = value.lower() + if lower in ("true", "1"): + return True + elif lower in ("false", "0"): + return False + else: + logger.warning(f"Unexpected boolean value: {value!r}, treating as string") + return value + elif param_type in ("object", "array"): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + else: + # Unknown type with string="false" - try JSON parse, fallback to string + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + def _parse_invoke_params(self, invoke_content: str, func_name: str, tools: List[Tool]) -> Dict: + """Parse all parameters from an invoke block content.""" + params = {} + for param_name, is_string_attr, param_value in self.parameter_regex.findall(invoke_content): + param_name = param_name.strip() + param_value = param_value.strip() + param_type = self._get_param_type(func_name, param_name, tools) + params[param_name] = self._convert_param_value(param_value, is_string_attr, param_type) + return params + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses DSML tool calls in the provided text. + """ + if self.bot_token not in text: + return StreamingParseResult(normal_text=text, calls=[]) + + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx > 0 else "" + tool_section = text[idx:] + + tool_indices = self._get_tool_indices(tools) + calls = [] + + try: + for func_name, invoke_content in self.invoke_regex.findall(tool_section): + func_name = func_name.strip() + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + continue + + params = self._parse_invoke_params(invoke_content, func_name, tools) + calls.append( + ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=json.dumps(params, ensure_ascii=False), + ) + ) + return StreamingParseResult(normal_text=normal_text, calls=calls) + except Exception as e: + logger.error(f"Error in DeepSeekV32 detect_and_parse: {e}") + return StreamingParseResult(normal_text=text) + + def finalize_streaming(self, tools: List[Tool]) -> StreamingParseResult: + """Finalize the last pending tool call when generation ends (EOS). + + The DSML format has no explicit end token, so the last invoke's last + parameter may remain unconfirmed. This method should be called when + the stream ends to close any open JSON and emit remaining parameters. 
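+
+        Illustrative call order (a hedged sketch; only parse_streaming_increment
+        and finalize_streaming are methods of this class, the chunk loop and the
+        `generated_chunks`/`tools` names are assumed caller code):
+
+            detector = DeepSeekV32Detector()
+            for chunk in generated_chunks:               # streamed model output
+                detector.parse_streaming_increment(chunk, tools)
+            result = detector.finalize_streaming(tools)  # flush last invoke at EOS
+            # result.calls then carries any remaining parameter fragments plus
+            # the closing "}" (or "{}" when no parameters were emitted)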
+ """ + if not self.current_tool_name_sent or self.current_tool_id < 0: + return StreamingParseResult() + + calls: List[ToolCallItem] = [] + current_text = self._buffer + + try: + # Find current invoke text + invoke_positions = [] + search_start = 0 + while True: + pos = current_text.find(self.invoke_prefix, search_start) + if pos == -1: + break + invoke_positions.append(pos) + search_start = pos + len(self.invoke_prefix) + + if self._invoke_count < len(invoke_positions): + invoke_start = invoke_positions[self._invoke_count] + invoke_text = current_text[invoke_start:] + + name_content_start = len(self.invoke_prefix) + name_end = invoke_text.find('">', name_content_start) + if name_end != -1: + func_name = invoke_text[name_content_start:name_end].strip() + invoke_body = invoke_text[name_end + 2 :] + + # Parse all remaining params (including the last unconfirmed one) + param_matches = list(self.parameter_regex.finditer(invoke_body)) + for i in range(self._param_count_in_invoke, len(param_matches)): + match = param_matches[i] + param_name = match.group(1).strip() + is_string_attr = match.group(2) + param_value = match.group(3).strip() + + param_type = self._get_param_type(func_name, param_name, tools) + converted_value = self._convert_param_value(param_value, is_string_attr, param_type) + serialized_value = json.dumps(converted_value, ensure_ascii=False) + + if not self._json_started: + json_fragment = "{" + f'"{param_name}": {serialized_value}' + self._json_started = True + else: + json_fragment = f', "{param_name}": {serialized_value}' + + self._accumulated_params[param_name] = converted_value + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += json_fragment + + # Close the JSON object + if self._json_started: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += "}" + elif self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="{}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = "{}" + + # Update prev_tool_call_arr + if self.current_tool_id < len(self.prev_tool_call_arr): + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params + + # Reset state + self._invoke_count += 1 + self.current_tool_id += 1 + self.current_tool_name_sent = False + self._json_started = False + self._accumulated_params = {} + self._buffer = "" + + return StreamingParseResult(normal_text="", calls=calls) + except Exception as e: + logger.error(f"Error in DeepSeekV32 finalize_streaming: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: + """ + Streaming incremental parsing for DeepSeek V3.2 DSML tool calls. + + The DSML format streams line-by-line with invoke/parameter tokens. + We accumulate parameters and only emit JSON fragments when a parameter's + value is confirmed complete (by seeing the next parameter/invoke boundary). 
+ """ + self._buffer += new_text + current_text = self._buffer + + # Check if we have any DSML content + if not self._in_tool_call_sequence: + if not self.has_tool_call(current_text): + # Check for partial start token + if self._ends_with_partial_token(current_text, self.bot_token): + return StreamingParseResult() + self._buffer = "" + return StreamingParseResult(normal_text=new_text) + self._in_tool_call_sequence = True + + if self._tool_indices is None: + self._tool_indices = self._get_tool_indices(tools) + + calls: List[ToolCallItem] = [] + + try: + # Find all invoke starts in current buffer + invoke_positions = [] + search_start = 0 + while True: + pos = current_text.find(self.invoke_prefix, search_start) + if pos == -1: + break + invoke_positions.append(pos) + search_start = pos + len(self.invoke_prefix) + + if not invoke_positions: + # Have bot_token but no invoke yet - keep buffering + return StreamingParseResult() + + # Process only the current (latest) invoke block + current_invoke_idx = self._invoke_count + if current_invoke_idx >= len(invoke_positions): + # All invokes already processed, keep buffering for new ones + return StreamingParseResult() + + invoke_start = invoke_positions[current_invoke_idx] + # Whether the current invoke is bounded by a next invoke + invoke_is_bounded = current_invoke_idx + 1 < len(invoke_positions) + if invoke_is_bounded: + invoke_end = invoke_positions[current_invoke_idx + 1] + else: + invoke_end = len(current_text) + + invoke_text = current_text[invoke_start:invoke_end] + + # Extract function name + name_start = invoke_text.find(self.invoke_prefix) + if name_start == -1: + return StreamingParseResult() + + name_content_start = name_start + len(self.invoke_prefix) + name_end = invoke_text.find('">', name_content_start) + if name_end == -1: + # Function name not complete yet + return StreamingParseResult() + + func_name = invoke_text[name_content_start:name_end].strip() + + # Initialize state for this tool call + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Send tool name if not sent yet + if not self.current_tool_name_sent: + if func_name and func_name in self._tool_indices: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + self._current_func_name = func_name + self._accumulated_params = {} + self._param_count_in_invoke = 0 + self._json_started = False + return StreamingParseResult(calls=calls) + return StreamingParseResult() + + # Parse parameters from the invoke block content + invoke_body = invoke_text[name_end + 2 :] # after '">' + + # Find all parameter starts within this invoke body + param_positions = [] + ps = 0 + while True: + pp = invoke_body.find(self.parameter_prefix, ps) + if pp == -1: + break + param_positions.append(pp) + ps = pp + len(self.parameter_prefix) + + # A parameter is "confirmed" when the next parameter/invoke boundary is visible, + # meaning the parameter's value won't grow further. + # For the last parameter in the invoke body, it's only confirmed if + # the invoke itself is bounded by a next invoke. 
+ confirmed_count = 0 + for pi in range(len(param_positions)): + if pi + 1 < len(param_positions): + confirmed_count += 1 + elif invoke_is_bounded: + confirmed_count += 1 + + # Only emit newly confirmed parameters + if confirmed_count > self._param_count_in_invoke: + param_matches = list(self.parameter_regex.finditer(invoke_body)) + for i in range(self._param_count_in_invoke, min(confirmed_count, len(param_matches))): + match = param_matches[i] + param_name = match.group(1).strip() + is_string_attr = match.group(2) + param_value = match.group(3).strip() + + param_type = self._get_param_type(func_name, param_name, tools) + converted_value = self._convert_param_value(param_value, is_string_attr, param_type) + serialized_value = json.dumps(converted_value, ensure_ascii=False) + + if not self._json_started: + json_fragment = "{" + f'"{param_name}": {serialized_value}' + self._json_started = True + else: + json_fragment = f', "{param_name}": {serialized_value}' + + self._accumulated_params[param_name] = converted_value + + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += json_fragment + + self._param_count_in_invoke = confirmed_count + + # Check if next invoke has started (meaning current one is complete) + if invoke_is_bounded: + # Current invoke is complete, close JSON and advance + if self._json_started: + close_fragment = "}" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=close_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += close_fragment + else: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="{}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = "{}" + + # Update prev_tool_call_arr + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params + + # Advance to next invoke, prune consumed buffer content + # Reset _invoke_count to 0 since buffer positions are now relative + self._buffer = current_text[invoke_end:] + self._invoke_count = 0 + self.current_tool_id += 1 + self.current_tool_name_sent = False + self._last_arguments = "" + self._accumulated_params = {} + self._param_count_in_invoke = 0 + self._json_started = False + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in DeepSeekV32 parse_streaming_increment: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. @@ -1455,6 +1939,7 @@ class FunctionCallParser: ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, + "deepseekv32": DeepSeekV32Detector, "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, @@ -1535,3 +2020,19 @@ def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]: final_normal_text = sp_result.normal_text return final_normal_text, final_calls + + def finalize_stream(self) -> Tuple[str, list[ToolCallItem]]: + """Finalize streaming when generation ends. + + For detectors that lack an explicit end-of-tool-call token (like DSML), + this closes any pending tool call JSON. For other detectors, this is a no-op. + + Returns: + A tuple of (normal_text, calls) like parse_stream_chunk. 
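+
+        Illustrative usage (a sketch; the constructor arguments and the chunk
+        loop are assumed, only parse_stream_chunk/finalize_stream come from
+        this class):
+
+            parser = FunctionCallParser(tools=tools, tool_call_parser="deepseekv32")
+            for chunk in stream:
+                text, calls = parser.parse_stream_chunk(chunk)
+            text, calls = parser.finalize_stream()   # closes the last DSML invoke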
+ """ + if not self.tools: + return "", [] + if hasattr(self.detector, "finalize_streaming"): + sp_result = self.detector.finalize_streaming(self.tools) + return sp_result.normal_text, sp_result.calls + return "", [] diff --git a/test/chat_template/tool_chat_template_deepseekv32.jinjia b/test/chat_template/tool_chat_template_deepseekv32.jinjia index b6d239dce..7bb0fc375 100644 --- a/test/chat_template/tool_chat_template_deepseekv32.jinjia +++ b/test/chat_template/tool_chat_template_deepseekv32.jinjia @@ -1,101 +1,202 @@ -{% if not add_generation_prompt is defined %} - {% set add_generation_prompt = false %} -{% endif %} -{% if not thinking is defined %} - {% set thinking = false %} -{% endif %} -{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false, is_only_sys=false, is_prefix=false) %} -{%- for message in messages %} - {%- if message['role'] == 'system' %} - {%- if ns.is_first_sp %} - {% set ns.system_prompt = ns.system_prompt + message['content'] %} - {% set ns.is_first_sp = false %} - {%- else %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} - {%- endif %} - {% set ns.is_only_sys = true %} - {%- endif %} -{%- endfor %} - -{% if tools is defined and tools is not none %} - {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} - {% for tool in tools %} - {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} - {% endfor %} - {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} -{% endif %} - -{{ bos_token }}{{ ns.system_prompt }} -{%- for message in messages %} - {%- if message['role'] == 'user' %} - {%- set ns.is_tool = false -%} - {%- set ns.is_first = false -%} - {%- set ns.is_last_user = true -%} - {{'<|User|>' + message['content']}} - {%- endif %} - {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} - {%- if ns.is_last_user or ns.is_only_sys %} - {{'<|Assistant|>'}} - {%- endif %} - {%- set ns.is_last_user = false -%} - {%- set ns.is_first = false %} - {%- set ns.is_tool = false -%} - {%- for tool in message['tool_calls'] %} - {%- set formatted_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments']|tojson %} - {%- if not ns.is_first %} - {%- if message['content'] is none %} - {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- else %} - {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- set ns.is_first = true -%} - {%- else %} - {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- endfor %} - 
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} - {%- endif %} - {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} - {%- if ns.is_last_user %} - {{'<|Assistant|>'}} - {%- if message['prefix'] is defined and message['prefix'] and thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {%- endif %} - {%- if message['prefix'] is defined and message['prefix'] %} - {%- set ns.is_prefix = true -%} - {%- endif %} - {%- set ns.is_last_user = false -%} - {%- if ns.is_tool %} - {{message['content'] + '<|end▁of▁sentence|>'}} - {%- set ns.is_tool = false -%} - {%- else %} - {%- set content = message['content'] -%} - {%- if '' in content %} - {%- set content = content.split('', 1)[1] -%} - {%- endif %} - {{content + '<|end▁of▁sentence|>'}} - {%- endif %} - {%- endif %} - {%- if message['role'] == 'tool' %} - {%- set ns.is_last_user = false -%} - {%- set ns.is_tool = true -%} - {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} - {%- endif %} - {%- if message['role'] != 'system' %} - {% set ns.is_only_sys = false %} - {%- endif %} +{#- ============================================================================ + DeepSeek-V3.2 DSML Chat Template + Converted from encoding_dsv32.py encode_messages function. + Uses DSML (DeepSeek Markup Language) format for tool calls. + ============================================================================ -#} +{%- set bos_token = "<|begin▁of▁sentence|>" -%} +{%- set eos_token = "<|end▁of▁sentence|>" -%} +{%- set thinking_start_token = "" -%} +{%- set thinking_end_token = "" -%} +{%- set dsml_token = "|DSML|" -%} + +{%- set system_msg_template = "{content}" -%} +{%- set user_msg_template = "<|User|>{content}<|Assistant|>" -%} +{%- set assistant_msg_template = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" -%} +{%- set thinking_template = "{reasoning_content}" -%} +{%- set tool_call_template = "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n{dsml_token}invoke>" -%} +{%- set tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n{dsml_token}function_calls>" -%} +{%- set tool_output_template = "\n{content}" -%} + +{%- set TOOLS_SYSTEM_TEMPLATE -%} +## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{{ dsml_token }}function_calls>" block like the following as part of your reply to the user: + +<{{ dsml_token }}function_calls> +<{{ dsml_token }}invoke name="$FUNCTION_NAME"> +<{{ dsml_token }}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{{ dsml_token }}parameter> +... +{{ dsml_token }}invoke> +<{{ dsml_token }}invoke name="$FUNCTION_NAME2"> +... +{{ dsml_token }}invoke> +{{ dsml_token }}function_calls> + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). 
+ +Here are the functions available in JSONSchema format: +{tool_schemas} +{%- endset -%} + +{%- if thinking_mode is not defined -%} + {%- set thinking_mode = "thinking" -%} +{%- endif -%} +{%- if drop_thinking is not defined -%} + {%- set drop_thinking = true -%} +{%- endif -%} +{%- if add_default_bos_token is not defined -%} + {%- set add_default_bos_token = true -%} +{%- endif -%} + +{#- Macro: encode_arguments_to_dsml -#} +{%- macro encode_arguments_to_dsml(arguments) -%} + {%- set ns = namespace(P_dsml_strs=[]) -%} + {%- if arguments is mapping -%} + {%- for k, v in arguments.items() -%} + {%- if v is string -%} + {%- set is_str = "true" -%} + {%- set value = v -%} + {%- else -%} + {%- set is_str = "false" -%} + {%- set value = v | tojson -%} + {%- endif -%} + {%- set p_dsml_str = "<" ~ dsml_token ~ "parameter name=\"" ~ k ~ "\" string=\"" ~ is_str ~ "\">" ~ value ~ dsml_token ~ "parameter>" -%} + {%- set ns.P_dsml_strs = ns.P_dsml_strs + [p_dsml_str] -%} + {%- endfor -%} + {%- endif -%} + {{- ns.P_dsml_strs | join("\n") -}} +{%- endmacro -%} + +{#- Macro: render_tools -#} +{%- macro render_tools(tools) -%} + {%- set ns = namespace(tools_json=[]) -%} + {%- for tool in tools -%} + {%- if tool.function is defined -%} + {%- set ns.tools_json = ns.tools_json + [tool.function | tojson] -%} + {%- else -%} + {%- set ns.tools_json = ns.tools_json + [tool | tojson] -%} + {%- endif -%} + {%- endfor -%} + {{- TOOLS_SYSTEM_TEMPLATE | replace("{tool_schemas}", ns.tools_json | join("\n")) }} +{% endmacro -%} + +{#- Macro: find_last_user_index -#} +{%- macro find_last_user_index(messages) -%} + {%- set ns = namespace(last_user_index=-1) -%} + {%- for msg in messages -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- if role in ['user', 'developer'] -%} + {%- set ns.last_user_index = loop.index0 -%} + {%- endif -%} + {%- endfor -%} + {{- ns.last_user_index -}} +{%- endmacro -%} + +{#- Macro: render_tool_calls_content -#} +{%- macro render_tool_calls_content(tool_calls) -%} + {%- set ns = namespace(formatted_calls=[]) -%} + {%- for tool_call in tool_calls -%} + {%- if tool_call.function is defined -%} + {%- set name = tool_call.function.name -%} + {%- set arguments = tool_call.function.arguments -%} + {%- else -%} + {%- set name = tool_call.name -%} + {%- set arguments = tool_call.arguments -%} + {%- endif -%} + {%- if arguments is string -%} + {%- set arguments = arguments | fromjson -%} + {%- endif -%} + {%- set formatted_call = "<" ~ dsml_token ~ "invoke name=\"" ~ name ~ "\">\n" ~ encode_arguments_to_dsml(arguments) ~ "\n" ~ dsml_token ~ "invoke>" -%} + {%- set ns.formatted_calls = ns.formatted_calls + [formatted_call] -%} + {%- endfor -%} + {{- "<" ~ dsml_token ~ "function_calls>\n" ~ ns.formatted_calls | join("\n") ~ "\n" ~ dsml_token ~ "function_calls>" -}} +{%- endmacro -%} + +{#- Macro: render_message -#} +{%- macro render_message(index, messages, thinking_mode) -%} + {%- set msg = messages[index] -%} + {%- set last_user_idx = find_last_user_index(messages) | int -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- set content = msg.content if msg.content is defined else (msg.get('content', '') or '') -%} + {%- set msg_tools = msg.tools if msg.tools is defined else msg.get('tools', []) -%} + {%- set tool_calls = msg.tool_calls if msg.tool_calls is defined else msg.get('tool_calls', []) -%} + {%- set reasoning_content = msg.reasoning_content if msg.reasoning_content is defined else (msg.get('reasoning_content', '') or '') -%} + + 
{%- if role == 'system' -%} + {{- content or '' -}} + {%- if msg_tools -%} + {{- "\n\n" ~ render_tools(msg_tools) -}} + {%- endif -%} + + {%- elif role == 'user' -%} + {{- "<|User|>" ~ content ~ "<|Assistant|>" -}} + {%- if index == last_user_idx and thinking_mode == "thinking" -%} + {{- thinking_start_token -}} + {%- else -%} + {{- thinking_end_token -}} + {%- endif -%} + + {%- elif role == 'tool' -%} + {%- set ns = namespace(prev_assistant_idx=-1) -%} + {%- for i in range(index - 1, -1, -1) -%} + {%- set check_role = messages[i].role if messages[i].role is defined else messages[i].get('role') -%} + {%- if check_role != 'tool' and ns.prev_assistant_idx == -1 -%} + {%- set ns.prev_assistant_idx = i -%} + {%- endif -%} + {%- endfor -%} + {%- set tool_call_order = index - ns.prev_assistant_idx -%} + {%- set assistant_msg = messages[ns.prev_assistant_idx] -%} + {%- set assistant_tool_calls = assistant_msg.tool_calls if assistant_msg.tool_calls is defined else assistant_msg.get('tool_calls', []) -%} + {%- if tool_call_order == 1 -%} + {{- "\n\n" -}} + {%- endif -%} + {{- "\n" ~ content -}} + {%- if tool_call_order == (assistant_tool_calls | length) -%} + {{- "\n" -}} + {%- if index >= last_user_idx and thinking_mode == "thinking" -%} + {{- "\n\n" ~ thinking_start_token -}} + {%- else -%} + {{- "\n\n" ~ thinking_end_token -}} + {%- endif -%} + {%- endif -%} + + {%- elif role == 'assistant' -%} + {%- set ns = namespace(thinking_part="", tool_calls_content="") -%} + {%- if tool_calls -%} + {%- set ns.tool_calls_content = "\n\n" ~ render_tool_calls_content(tool_calls) -%} + {%- endif -%} + {%- set summary_content = content or "" -%} + {%- if thinking_mode == "thinking" and index > last_user_idx -%} + {%- set ns.thinking_part = reasoning_content ~ thinking_end_token -%} + {%- endif -%} + {{- ns.thinking_part ~ summary_content ~ ns.tool_calls_content ~ "<|end▁of▁sentence|>" -}} + {%- endif -%} +{%- endmacro -%} + +{#- Main template body -#} +{%- set full_messages = messages -%} + +{#- Handle tools in top-level (OpenAI format) -#} +{%- if tools is defined and tools is not none -%} + {%- set ns_sys = namespace(has_system=false, sys_idx=-1) -%} + {%- for msg in full_messages -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- if role == 'system' and not ns_sys.has_system -%} + {%- set ns_sys.has_system = true -%} + {%- set ns_sys.sys_idx = loop.index0 -%} + {%- endif -%} + {%- endfor -%} +{%- endif -%} + +{%- if add_default_bos_token -%} + {{- bos_token -}} +{%- endif -%} + +{#- If tools defined at top level but no system message has them, prepend tools info -#} +{%- if tools is defined and tools is not none -%} + {{- render_tools(tools) -}} +{%- endif -%} + +{%- for msg in full_messages -%} + {{- render_message(loop.index0, full_messages, thinking_mode) -}} {%- endfor -%} -{% if add_generation_prompt and not ns.is_tool%} - {% if ns.is_last_user or ns.is_only_sys or not ns.is_prefix %} - {{'<|Assistant|>'}} - {%- if not thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {% endif %} -{% endif %} diff --git a/test/test_api/test_gsmk.py b/test/test_api/test_gsmk.py new file mode 100644 index 000000000..2d9ead65b --- /dev/null +++ b/test/test_api/test_gsmk.py @@ -0,0 +1,265 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests 
+from tqdm import tqdm + +INVALID = -9999999 + +SYSTEM_PROMPT_TARGET_LEN = 18192 + + +def generate_system_prompt(): + """Generate a system prompt of approximately 8192 characters.""" + base = ( + "You are a highly capable math assistant. Your task is to solve grade school math problems step by step. " + "Show your reasoning clearly and provide the final numerical answer. " + "Break down each problem into smaller steps and verify your calculations. " + "Always end your answer with the format: #### . " + ) + # Repeat base text to reach target length + repeats = SYSTEM_PROMPT_TARGET_LEN // len(base) + 1 + prompt = (base * repeats)[:SYSTEM_PROMPT_TARGET_LEN] + return prompt + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + yield json.loads(line) + + +def dump_state_text(filename: str, states: list, mode: str = "w"): + """Dump program state in a text file.""" + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + fout.write(f"==== {i} ====\n{s}\n") + else: + fout.write(f"==== {i} ====\n{str(s)}\n") + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as file, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + size = file.write(chunk) + bar.update(size) + + return filename + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + "repetition_penalty": 1.0, + "top_p": 1.0, + "top_k": 1, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + "Invalid API response format. 
'generated_text' should be a non-empty list, " + f"got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + # First try to find the answer after "####" marker (GSM8K format) + match = re.search(r"####\s*(-?\d+)", answer_str) + if match: + try: + return ast.literal_eval(match.group(1)) + except SyntaxError: + pass + # Fallback: find all numbers and take the last one + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=256) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument( + "--system-prompt", action="store_true", help="Prepend an 8192-character system prompt to each request" + ) + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + system_prefix = "" + if args.system_prompt: + system_prefix = generate_system_prompt() + "\n\n" + print(f"System prompt enabled: {len(system_prefix)} characters") + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. 
" + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=system_prefix + few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args) From 50b45d46afcf0cb6a9c8fe2ad8cd4b7ce6a99d35 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 08:39:18 +0000 Subject: [PATCH 16/23] fix --- .../deepseek3_2/layer_infer/nsa_indexer_layer_inder.py | 3 +-- .../deepseek3_2/layer_infer/transformer_layer_infer.py | 6 +++--- .../layer_weights/nsa_indexer_layer_weight.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 10 +++------- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 3855bf590..61b4962f1 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -16,11 +16,10 @@ class NSAIndexerInfer(BaseLayerInfer): - def __init__(self, layer_idx, network_config, mode=[]): + def __init__(self, layer_idx, network_config): super().__init__() self.layer_idx_ = layer_idx self.network_config_ = network_config - self.mode = mode self.index_topk = network_config["index_topk"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = 1 diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index bc8bb9c6b..b7326c36e 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -14,11 +14,11 @@ class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def 
__init__(self, layer_num, network_config): self.index_topk = network_config["index_topk"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) - self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_, mode=mode) + self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_) self.topk_indices = None # Initialize NSA attention backend (singleton, lazy initialization) diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 9ccfbe97e..9e1337b0f 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -7,8 +7,8 @@ class NSAIndexerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg): + super().__init__(layer_num, data_type, network_config, quant_cfg) return @override diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py index 2a03e1d6a..adcba51cc 100644 --- a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -3,14 +3,10 @@ class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.index_topk = network_config["index_topk"] - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) self.indexer_layer_weight = NSAIndexerWeight( - layer_num=layer_num, - data_type=data_type, - network_config=network_config, - mode=mode, - quant_cfg=quant_cfg + layer_num=layer_num, data_type=data_type, network_config=network_config, quant_cfg=quant_cfg ) return From c1249d40f40c21047b375d130660d73c0c5536e0 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 08:45:23 +0000 Subject: [PATCH 17/23] fix --- .../deepseek3_2/layer_infer/nsa_indexer_layer_inder.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 61b4962f1..390853271 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -1,8 +1,6 @@ from sgl_kernel import fast_topk_transform_fused import deep_gemm import torch -import torch.nn.functional as F - from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo @@ -138,10 +136,7 @@ def _get_q_k_bf16( q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - # TODO - k = F.layer_norm( - k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps - 
).type_as(k) + k = layer_weight.k_norm_(k, eps=self.eps) # Slice position_cos and position_sin to match actual token length actual_seq_len = q.shape[0] From 26950cf928f0825be0010237875f796b56d6bbcc Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 09:50:12 +0000 Subject: [PATCH 18/23] fix --- .../attention/nsa/flashmla_sparse.py | 7 +- lightllm/common/infer_utils.py | 9 +- .../triton_kernel/fp8_mqa_logits.py | 124 +++-- .../triton_kernel/token_group_quant.py | 9 +- lightllm/server/build_prompt.py | 50 +- lightllm/server/encoding_dsv32.py | 429 ++++++++++++++++++ 6 files changed, 564 insertions(+), 64 deletions(-) create mode 100644 lightllm/server/encoding_dsv32.py diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 3eec98f05..2c347ed32 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -1,3 +1,6 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/nsa_backend.py +# Uses sgl_kernel.flash_mla and sgl_kernel.flash_attn from the sglang kernel library. + import dataclasses import torch from typing import Tuple, TYPE_CHECKING @@ -112,8 +115,8 @@ def _nsa_decode_att( q_nope, q_rope = q # Extract k_rope and kv_nope from the KV buffer - k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, 1, 1, qk_rope_head_dim) - kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, 1, 1, kv_lora_rank) + k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank) o_tensor = flash_attn_with_kvcache( q=q_rope, diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index 26cf973be..e1b9cc383 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,15 +1,8 @@ -import torch from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req_prefill def init_req_to_token_indexes( - req_to_token_indexs, - b_req_idx, - b_seq_len, - b_ready_cache_len, - b_start_loc, - alloc_mem_index, - max_q_seq_len, + req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, b_start_loc, alloc_mem_index, max_q_seq_len ): copy_kv_index_to_req_prefill( req_to_token_indexs=req_to_token_indexs, diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py index 2fc92662a..e8f1bbfa2 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -1,3 +1,4 @@ +# import triton import triton.language as tl import torch @@ -5,13 +6,27 @@ @triton.jit def _fp8_paged_mqa_logits_kernel( - Q_ptr, KV_ptr, KVScale_ptr, Weights_ptr, MemIndex_ptr, - CuSeqlenKs_ptr, CuSeqlenKe_ptr, Output_ptr, - seq_len, seq_len_kv, num_heads, head_dim, - stride_q_seq, stride_q_head, stride_q_dim, - stride_kv_pool, stride_kv_dim, - stride_w_seq, stride_w_head, - stride_o_seq, stride_o_kv, + Q_ptr, + KV_ptr, + KVScale_ptr, + Weights_ptr, + MemIndex_ptr, + CuSeqlenKs_ptr, + CuSeqlenKe_ptr, + Output_ptr, + seq_len, + seq_len_kv, + num_heads, + head_dim, + stride_q_seq, + stride_q_head, + stride_q_dim, + stride_kv_pool, + stride_kv_dim, + stride_w_seq, + stride_w_head, + stride_o_seq, + stride_o_kv, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, @@ -19,80 +34,79 @@ def _fp8_paged_mqa_logits_kernel( pid_m = 
tl.program_id(0) pid_n = tl.program_id(1) - + # Compute the range of seq positions this block handles start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N - + # Offset arrays for this block offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) - + # Initialize accumulator for logits logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - + # Create masks mask_m = offs_m < seq_len mask_n = offs_n < seq_len_kv - + # Load mem_indices for the KV positions mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) - + # Load scales for K scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) - + # Loop over all heads for h in range(num_heads): # Load weights for this head - weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, - mask=mask_m, other=0.0) - + weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, mask=mask_m, other=0.0) + # Initialize score accumulator for this head score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - + # Loop over head_dim in blocks for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): d_start = d_block * BLOCK_SIZE_D offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) mask_d = offs_d < head_dim - + # Load Q for this head and dimension block # Q shape: (seq_len, num_heads, head_dim) q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) - + # Load K for this dimension block # KV shape: (pool_size, head_dim) as FP8 data k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim mask_k = mask_n[:, None] & mask_d[None, :] k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) - + # Apply scale to K (scale is per-row of K) k = k * scales[:, None] - + # Compute partial dot product: q @ k.T # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) score += tl.dot(q, tl.trans(k)) - + # Apply ReLU to score score = tl.maximum(score, 0.0) - + # Multiply by weights and accumulate to logits logits += score * weights[:, None] - + # Apply mask based on cu_seqlen_ks and cu_seqlen_ke mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) - + mask_lo = offs_n[None, :] >= mask_ks[:, None] mask_hi = offs_n[None, :] < mask_ke[:, None] mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] - + # Apply mask (-inf for masked positions) - logits = tl.where(mask_valid, logits, float('-inf')) - + logits = tl.where(mask_valid, logits, float("-inf")) + # Store output out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) @@ -100,40 +114,54 @@ def _fp8_paged_mqa_logits_kernel( def fp8_paged_mqa_logits( - q: torch.Tensor, + q: torch.Tensor, kv: torch.Tensor, kv_scale: torch.Tensor, - weights: torch.Tensor, - mem_index: torch.Tensor, - cu_seqlen_ks: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, - out: torch.Tensor = None + out: torch.Tensor = None, ) -> torch.Tensor: seq_len, num_heads, head_dim = q.shape seq_len_kv = mem_index.shape[0] - + if out is None: output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) else: output = out 
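+    # Descriptive note added for clarity (shapes inferred from the kernel's stride
+    # usage above; this is documentation, not an extra runtime check):
+    #   q:            [seq_len, num_heads, head_dim]   queries (typically fp8)
+    #   kv:           [pool_size, head_dim]            paged K cache rows (typically fp8)
+    #   kv_scale:     [pool_size]                      float32 per-row dequant scales
+    #   weights:      [seq_len, num_heads]             per-token, per-head mixing weights
+    #   mem_index:    [seq_len_kv]                     rows of `kv` visible to this batch
+    #   cu_seqlen_ks / cu_seqlen_ke: [seq_len]         valid KV window [ks, ke) per query row
+    #   output:       [seq_len, seq_len_kv]            float32 logits, -inf outside the window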
- + BLOCK_SIZE_M = 16 BLOCK_SIZE_N = 64 - BLOCK_SIZE_D = 128 - + BLOCK_SIZE_D = 128 + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) - + _fp8_paged_mqa_logits_kernel[grid]( - q, kv, kv_scale, weights, mem_index, - cu_seqlen_ks, cu_seqlen_ke, output, - seq_len, seq_len_kv, num_heads, head_dim, - q.stride(0), q.stride(1), q.stride(2), - kv.stride(0), kv.stride(1), - weights.stride(0), weights.stride(1), - output.stride(0), output.stride(1), + q, + kv, + kv_scale, + weights, + mem_index, + cu_seqlen_ks, + cu_seqlen_ke, + output, + seq_len, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + weights.stride(0), + weights.stride(1), + output.stride(0), + output.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_D=BLOCK_SIZE_D, ) - - return output \ No newline at end of file + + return output diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py index dbf5c5199..807986413 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py + import triton import triton.language as tl import torch @@ -7,6 +9,7 @@ fp8_max = 448.0 fp8_dtype = torch.float8_e4m3fn + @triton.jit def _per_token_group_quant_mla_deep_gemm_masked_fp8( y_ptr, @@ -46,9 +49,7 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8( mask = cols < group_size for gid in range(NUM_GROUP): - y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to( - tl.float32 - ) + y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(tl.float32) _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / fp8_max y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) @@ -100,4 +101,4 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8( BLOCK_SIZE, ) - return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m \ No newline at end of file + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index f770459a5..d77184863 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,11 +1,28 @@ +import json +import os + tokenizer = None +_model_type = None def init_tokenizer(args): - global tokenizer + global tokenizer, _model_type from lightllm.server.tokenizer import get_tokenizer tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + + # Detect model type for specialized encoding (e.g. 
DeepSeek-V3.2) + config_path = os.path.join(args.model_dir, "config.json") + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + model_config = json.load(f) + _model_type = model_config.get("model_type", None) + # Check architectures as fallback + if _model_type is None: + archs = model_config.get("architectures", []) + if any("DeepseekV32" in a for a in archs): + _model_type = "deepseek_v32" + chat_path = args.chat_template if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: @@ -14,9 +31,14 @@ def init_tokenizer(args): async def build_prompt(request, tools) -> str: - global tokenizer + global tokenizer, _model_type # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] + + # Use DeepSeek-V3.2 native encoding when applicable + if _model_type == "deepseek_v32": + return _build_prompt_dsv32(messages, tools, request) + kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -40,3 +62,27 @@ async def build_prompt(request, tools) -> str: tools=tools, ) return input_str + + +def _build_prompt_dsv32(messages, tools, request): + from lightllm.server.encoding_dsv32 import encode_messages + + # Inject tools into system message if present + if tools is not None and len(tools) > 0: + wrapped_tools = [t if "function" in t else {"function": t} for t in tools] + if messages and messages[0].get("role") == "system": + messages[0]["tools"] = wrapped_tools + else: + messages.insert(0, {"role": "system", "tools": wrapped_tools}) + + # Determine thinking mode from request + thinking = False + if request.chat_template_kwargs: + thinking = request.chat_template_kwargs.get("thinking", False) or request.chat_template_kwargs.get( + "enable_thinking", False + ) + + thinking_mode = "thinking" if thinking else "chat" + drop_thinking = messages[-1]["role"] == "user" if messages else True + + return encode_messages(messages, thinking_mode=thinking_mode, drop_thinking=drop_thinking) diff --git a/lightllm/server/encoding_dsv32.py b/lightllm/server/encoding_dsv32.py new file mode 100644 index 000000000..3ac4b8371 --- /dev/null +++ b/lightllm/server/encoding_dsv32.py @@ -0,0 +1,429 @@ +# Adapted from vLLM's deepseek_v32_encoding.py +# (https://github.com/vllm-project/vllm), which was originally adapted from +# https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py +# +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import json +import re +from typing import Any + +# flake8: noqa: E501 +TOOLS_SYSTEM_TEMPLATE = """## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user: +<{dsml_token}function_calls> +<{dsml_token}invoke name="$FUNCTION_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$FUNCTION_NAME2"> +... + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). 
+If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: +<{dsml_token}function_calls> +... + + +... + +{thinking_start_token}...thinking about results{thinking_end_token} +Here are the functions available in JSONSchema format: + +{tool_schemas} + +""" + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" +system_msg_template: str = "{content}" +user_msg_template: str = "<|User|>{content}<|Assistant|>" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" +thinking_template = "{reasoning}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = '<{dsml_token}invoke name="{name}">\n{arguments}\n' +tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n" + +tool_output_template: str = "\n{content}" + + +def to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: dict) -> str: + p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}""" + P_dsml_strs = [] + if isinstance(tool_call["arguments"], str): + arguments = json.loads(tool_call["arguments"]) + else: + arguments = tool_call["arguments"] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + value=v if isinstance(v, str) else to_json(v), + ) + + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments(tool_name, tool_args): + def _decode_value(key, value, string): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}" + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools): + tools_json = [to_json(t) for t in tools] + + return TOOLS_SYSTEM_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages): + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +def render_message(index, messages, thinking_mode): + if not (0 <= index < len(messages)): + raise ValueError(f"Index {index} out of range for messages list of length {len(messages)}") + if thinking_mode not in ["chat", "thinking"]: + raise ValueError(f"Invalid thinking_mode `{thinking_mode}`") + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = 
msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning = msg.get("reasoning") + is_prefix = msg.get("prefix", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + + if response_format: + prompt += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + elif role == "developer": + if not content: + raise ValueError(f"Invalid message for role `{role}`: {msg}") + content_developer = "" + if tools: + content_developer += "\n\n" + render_tools(tools) + + if response_format: + content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + content_developer += "\n\n# The user's message is: {}".format(content) + + prompt += user_msg_template.format(content=content_developer) + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "user": + prompt += user_msg_template.format(content=content) + + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "tool": + prev_assistant_idx = index - 1 + assistant_msg = messages[prev_assistant_idx] + while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool": + prev_assistant_idx -= 1 + assistant_msg = messages[prev_assistant_idx] + + if not (index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant"): + raise ValueError(f"Invalid messages at {index}:\n{assistant_msg}") + + tool_call_order = index - prev_assistant_idx + assistant_tool_calls = assistant_msg.get("tool_calls") + if not (assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order): + raise ValueError("No tool calls but found tool output") + + if tool_call_order == 1: + prompt += "\n\n" + + prompt += tool_output_template.format(content=content) + + if tool_call_order == len(assistant_tool_calls): + prompt += "\n" + + if index >= last_user_idx and thinking_mode == "thinking": + prompt += "\n\n" + thinking_start_token + else: + prompt += "\n\n" + thinking_end_token + + elif role == "assistant": + thinking_part = "" + + tool_calls_content = "" + if tool_calls: + tool_calls = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tool_call.get("name"), + arguments=encode_arguments_to_dsml(tool_call), + ) + for tool_call in tool_calls + ] + tool_calls_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, tool_calls="\n".join(tool_calls) + ) + + summary_content = content or "" + + if thinking_mode == "thinking" and index > last_user_idx: + if not (reasoning or tool_calls): + raise ValueError( + f"ThinkingMode: {thinking_mode}, invalid message without reasoning/tool_calls `{msg}` after last user message" + ) + thinking_part = thinking_template.format(reasoning=reasoning or "") + thinking_end_token + + if not tool_calls and is_prefix: + prompt += summary_content + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tool_calls_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + return prompt + + +def drop_thinking_messages(messages, last_user_idx=None): + 
messages_wo_thinking = [] + last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in ["user", "system", "tool"] or idx >= last_user_idx: + messages_wo_thinking.append(msg) + continue + + elif role == "assistant": + msg_wo_thinking = copy.copy(msg) + msg_wo_thinking.pop("reasoning", None) + messages_wo_thinking.append(msg_wo_thinking) + + return messages_wo_thinking + + +def encode_messages( + messages, + thinking_mode, + context=None, + drop_thinking=True, + add_default_bos_token=True, +): + context = context if context else [] + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + if thinking_mode == "thinking" and drop_thinking: + full_messages = drop_thinking_messages(full_messages) + + for idx in range(len(messages)): + prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode) + + return prompt + + +def _read_until_stop(index, text, stop): + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index, text): + tool_calls = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token]) + if _ != ">\n": + raise RuntimeError("Tool call format error") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise RuntimeError("Missing special token") + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL) + if len(p_tool_name) != 1: + raise RuntimeError("Tool name format error") + tool_name = p_tool_name[0] + + tool_args = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"]) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + if len(param_kv) != 1: + raise RuntimeError("Parameter format error") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise RuntimeError("Duplicate parameter name") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n": + raise RuntimeError("Parameter format error") + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +# NOTE: This function is designed to parse only correctly +# formatted string and will not attempt to correct malformed output +# that may be generated by the model. 
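+# On the happy path it returns an OpenAI-style assistant message, e.g.
+# (illustrative values):
+#   {"role": "assistant", "content": "...", "reasoning": "...", "tool_calls": []}
+# where "tool_calls" is filled via tool_calls_to_openai_format() when the text
+# contains a DSML function_calls block.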
+def parse_message_from_completion_text(text, thinking_mode): + summary_content, reasoning, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}function_calls" + + is_thinking, is_tool_calling = thinking_mode == "thinking", False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token]) + reasoning = content_delta + if stop_token != thinking_end_token: + raise RuntimeError("Invalid thinking format") + + index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token]) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + if stop_token != eos_token: + raise RuntimeError("Invalid summary format") + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + if tool_ends_text: + raise RuntimeError("Unexpected content after tool calls") + + if not (len(text) == index and stop_token in [eos_token, None]): + raise RuntimeError("Unexpected content at end") + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + if sp_token in summary_content or sp_token in reasoning: + raise RuntimeError("Unexpected special token in content") + + return { + "role": "assistant", + "content": summary_content, + "reasoning": reasoning, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } From 73bf0daa26360c7c3bd3102558b6be1bd3b3dd47 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 13:33:44 +0000 Subject: [PATCH 19/23] save --- .../fused_moe/fused_moe_weight.py | 2 +- lightllm/common/req_manager.py | 26 +- .../deepseek3_2}/encoding_dsv32.py | 0 lightllm/models/deepseek3_2/infer_struct.py | 142 +++-- .../layer_infer/nsa_indexer_layer_inder.py | 8 +- .../layer_infer/transformer_layer_infer.py | 11 +- .../layer_weights/nsa_indexer_layer_weight.py | 35 +- lightllm/models/deepseek3_2/mem_manager.py | 6 - lightllm/models/deepseek3_2/model.py | 104 +++- .../destindex_copy_indexer_ks.py | 168 ----- .../triton_kernel/extract_indexer_ks.py | 265 +------- .../triton_kernel/fp8_mqa_logits.py | 26 - lightllm/server/api_cli.py | 9 +- lightllm/server/api_openai.py | 33 - lightllm/server/build_prompt.py | 50 +- lightllm/server/core/objs/sampling_params.py | 53 +- lightllm/server/function_call_parser.py | 586 ++++++------------ .../router/dynamic_prompt/radix_cache.py | 3 +- .../mode_backend/generic_post_process.py | 17 +- lightllm/server/tokenizer.py | 11 + 20 files changed, 485 insertions(+), 1070 deletions(-) rename lightllm/{server => models/deepseek3_2}/encoding_dsv32.py (100%) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03..8f54e14a7 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -98,7 +98,7 @@ def _init_parallel_params(self): self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_ if self.enable_ep_moe: assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe" - logger.info( + logger.debug( f"global_rank {self.global_rank_} layerindex {self.layer_num_} " f"redundancy_expertids: 
{self.redundancy_expert_ids}" ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 40c8aa993..33bdca447 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -7,6 +7,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.utils.config_utils import get_vocab_size +from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager logger = init_logger(__name__) @@ -155,7 +156,11 @@ def init_req_sampling_params(self, req): else: self.req_to_out_token_id_counter[req.req_idx].fill_(0) if req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics: - prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True) + prompt_ids = g_pin_mem_manager.gen_from_list( + key="prompt_ids_for_penalty", + data=req.shm_req.get_prompt_ids_numpy(), + dtype=torch.int32, + ).cuda(non_blocking=True) token_id_counter( prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx] ) @@ -214,22 +219,13 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List): cum_sum_len += len(id_to_count) p_cumsum_seq_len.append(cum_sum_len) - from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager - - p_token_ids_tensor = g_pin_mem_manager.alloc_pin_tensor( - key="p_token_ids", size=len(p_token_ids), dtype=torch.int32 - ) - p_token_ids_tensor.numpy()[:] = p_token_ids - - p_token_counts_tensor = g_pin_mem_manager.alloc_pin_tensor( - key="p_token_counts", size=len(p_token_counts), dtype=torch.int32 + p_token_ids_tensor = g_pin_mem_manager.gen_from_list(key="p_token_ids", data=p_token_ids, dtype=torch.int32) + p_token_counts_tensor = g_pin_mem_manager.gen_from_list( + key="p_token_counts", data=p_token_counts, dtype=torch.int32 ) - p_token_counts_tensor.numpy()[:] = p_token_counts - - p_cumsum_seq_len_tensor = g_pin_mem_manager.alloc_pin_tensor( - key="p_cumsum_seq_len", size=len(p_cumsum_seq_len), dtype=torch.int32 + p_cumsum_seq_len_tensor = g_pin_mem_manager.gen_from_list( + key="p_cumsum_seq_len", data=p_cumsum_seq_len, dtype=torch.int32 ) - p_cumsum_seq_len_tensor.numpy()[:] = p_cumsum_seq_len return ( p_token_ids_tensor.cuda(non_blocking=True), diff --git a/lightllm/server/encoding_dsv32.py b/lightllm/models/deepseek3_2/encoding_dsv32.py similarity index 100% rename from lightllm/server/encoding_dsv32.py rename to lightllm/models/deepseek3_2/encoding_dsv32.py diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index e0cca499b..779c2fc2d 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -4,7 +4,7 @@ from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager -class Deepseek3_2FlashAttentionStateInfo(Deepseek2InferStateInfo): +class Deepseek3_2InferStateInfo(Deepseek2InferStateInfo): _shared_nsa_buffers = None def __init__(self): @@ -54,7 +54,7 @@ def _check_use_cuda_graph_buffers(self): and hasattr(model, "graph_max_batch_size") and hasattr(model, "graph_max_len_in_batch") and self.batch_size <= model.graph_max_batch_size - and self.max_len_in_batch <= model.graph_max_len_in_batch + and self.max_kv_seq_len <= model.graph_max_len_in_batch ): return True return False @@ -68,12 +68,13 @@ def init_some_extra_state(self, model): self.indexer_ks_buffer = 
self.mem_manager.indexer_ks_buffer if self.is_prefill: - pass + self._init_nsa_indexing_prefill() else: if self.b_ready_cache_len is None: self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() + buffer = None if use_cuda_graph_buffers: buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) @@ -84,86 +85,125 @@ def init_some_extra_state(self, model): self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") - self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) + self.nsa_cache_seqlens.copy_(self.b_kv_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) self.nsa_cu_seqlens_k[0] = 0 - self._init_nsa_indexing_structures() + self._init_nsa_indexing_decode(use_cuda_graph_buffers, buffer) - def _init_nsa_indexing_structures(self): - """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer. + def _init_nsa_indexing_decode(self, use_cuda_graph_buffers, buffer): + """Optimized NSA indexing for decode: b_q_seq_len=1 per request. - Fully vectorized: eliminates per-request .item() CPU-GPU syncs. + In decode, each request generates exactly 1 token, so: + - total_q_len = batch_size (no .item() needed) + - ks[i] = cumsum_offset[i], ke[i] = cumsum_offset[i] + 1 + - lengths[i] = b_seq_len[i] + - No repeat_interleave, no token_in_req math needed. """ b_seq_len = self.b_seq_len + b_req_idx = self.b_req_idx + num_seq = self.batch_size + + # Cumulative seq_len offsets for ks/ke: [0, s0, s0+s1, ...] 
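+        # Worked example (hypothetical values, two decode requests):
+        #   b_seq_len = [3, 5]  ->  cum_seq = [3, 8]
+        #   ks = [0, 3], ke = [1, 4], lengths = [3, 5]
+        # i.e. exactly one query row per request, offset into the flattened
+        # per-request KV layout.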
+ cum_seq = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) + + if use_cuda_graph_buffers: + model = self._model_ref() + max_seq_len = model.graph_max_len_in_batch + + # ks, ke, lengths — write directly into buffer slices + buf_ks = buffer["ks"][:num_seq] + buf_ke = buffer["ke"][:num_seq] + buf_lengths = buffer["lengths"][:num_seq] + + # ks[0] = 0, ks[i] = cum_seq[i-1] + buf_ks[0] = 0 + if num_seq > 1: + buf_ks[1:].copy_(cum_seq[: num_seq - 1]) + # ke = ks + 1 + torch.add(buf_ks, 1, out=buf_ke) + # lengths = b_seq_len + buf_lengths.copy_(b_seq_len.int()) + + self.ks = buf_ks + self.ke = buf_ke + self.lengths = buf_lengths + + # page_table: zero buffer slice, then fill valid entries + page_table = buffer["page_table_size_1"][:num_seq, :max_seq_len] + page_table.zero_() + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=b_seq_len.device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + + # req_all_mem_index: use padded [num_seq * max_seq_len] layout + # Downstream uses ks/ke masking so padded entries are safe + max_total_seq = num_seq * max_seq_len + buf_mem = buffer["req_all_mem_index"][:max_total_seq] + buf_mem.copy_(all_rows.reshape(-1)) + self.req_all_mem_index = buf_mem + else: + # Non-CUDA-graph decode: simplified formulas, fresh tensors + max_seq_len = b_seq_len.max().item() + + # ks/ke/lengths + seq_offsets = torch.empty_like(cum_seq) + seq_offsets[0] = 0 + if num_seq > 1: + seq_offsets[1:] = cum_seq[:-1] + + self.ks = seq_offsets + self.ke = (seq_offsets + 1).int() + self.lengths = b_seq_len.int() + + # page_table and req_all_mem_index + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=b_seq_len.device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + + page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=b_seq_len.device) + page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + + self.req_all_mem_index = all_rows[valid_mask] + + def _init_nsa_indexing_prefill(self): + """NSA indexing for prefill: variable q lengths, generic vectorized path.""" + b_seq_len = self.b_seq_len b_q_seq_len = self.b_q_seq_len b_req_idx = self.b_req_idx num_seq = b_req_idx.shape[0] device = b_seq_len.device - # Only 3 scalar syncs needed (for tensor shapes) max_seq_len = b_seq_len.max().item() total_q_len = b_q_seq_len.sum().item() - total_seq_len = b_seq_len.sum().item() - # --- page_table_size_1 and req_all_mem_index (vectorized gather) --- + # page_table_size_1 and req_all_mem_index all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] seq_range = torch.arange(max_seq_len, device=device) valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) - # page_table_size_1: [batch, max_seq_len] zero-padded memory indices page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=device) page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + self.req_all_mem_index = all_rows[valid_mask] - # req_all_mem_index: flattened valid memory indices across all requests - req_all_mem_index = all_rows[valid_mask] - - # --- ks, ke, lengths (vectorized computation) --- - # Cumulative seq_len offsets: [0, seq_len[0], seq_len[0]+seq_len[1], ...] 
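+        # Worked example for the block below (hypothetical values, two requests
+        # with no prefix cache): b_seq_len = [3, 2], b_q_seq_len = [3, 2]
+        #   seq_offsets  = [0, 3]
+        #   req_indices  = [0, 0, 0, 1, 1]
+        #   token_in_req = [0, 1, 2, 0, 1]
+        #   ks = [0, 0, 0, 3, 3], ke = [1, 2, 3, 4, 5], lengths = [1, 2, 3, 1, 2]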
+ # ks, ke, lengths — generic vectorized for variable q lengths cum_seq = torch.cumsum(b_seq_len, dim=0) seq_offsets = torch.zeros_like(cum_seq) seq_offsets[1:] = cum_seq[:-1] - # Expand per-request values to per-token using repeat_interleave req_indices = torch.repeat_interleave(torch.arange(num_seq, device=device), b_q_seq_len) - # Token position within each request's q_seq cum_q = torch.cumsum(b_q_seq_len, dim=0) q_offsets = torch.zeros_like(cum_q) q_offsets[1:] = cum_q[:-1] token_in_req = torch.arange(total_q_len, device=device) - q_offsets[req_indices] - # ks[t] = seq_offset of request owning token t - # ke[t] = seq_offset + position_in_q + 1 - # lengths[t] = seq_len - q_seq_len + position_in_q + 1 - ks = seq_offsets[req_indices].int() - ke = (seq_offsets[req_indices] + token_in_req + 1).int() - lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req + 1).int() - - # --- Assign results (CUDA graph buffer or new tensors) --- - use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() - - if use_cuda_graph_buffers: - model = self._model_ref() - buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) - buffer = buffers[self.microbatch_index] - - self.ks = buffer["ks"][:total_q_len] - self.ke = buffer["ke"][:total_q_len] - self.lengths = buffer["lengths"][:total_q_len] - self.page_table_size_1 = buffer["page_table_size_1"][:num_seq, :max_seq_len] - self.req_all_mem_index = buffer["req_all_mem_index"][:total_seq_len] - - self.ks.copy_(ks) - self.ke.copy_(ke) - self.lengths.copy_(lengths) - self.page_table_size_1.copy_(page_table) - self.req_all_mem_index.copy_(req_all_mem_index) - else: - self.ks = ks - self.ke = ke - self.lengths = lengths - self.page_table_size_1 = page_table - self.req_all_mem_index = req_all_mem_index + self.ks = seq_offsets[req_indices].int() + self.ke = (seq_offsets[req_indices] + token_in_req + 1).int() + self.lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req + 1).int() diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 390853271..7a9aeb46c 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -3,7 +3,7 @@ import torch from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks @@ -70,7 +70,7 @@ def get_indices( self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: NSAIndexerWeight, ) -> torch.Tensor: @@ -113,7 +113,7 @@ def get_indices( score=logits, lengths=lengths, page_table_size_1=page_table_1, - cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_q=infer_state.b1_cu_q_seq_len, topk=self.index_topk, ) @@ -130,7 +130,7 @@ def _get_q_k_bf16( self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: 
Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: NSAIndexerWeight, ): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index b7326c36e..9dba923cc 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -5,7 +5,7 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd @@ -37,7 +37,7 @@ def _get_nsa_backend(self): def _get_qkv( self, input: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) @@ -47,9 +47,6 @@ def _get_qkv( ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - # Process all tokens for indexer - # Note: Prefix cache slicing optimization is disabled due to batch structure - # mismatch issues with fast_topk_transform_fused kernel self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) q = layer_weight.q_b_proj_.mm(q) @@ -76,7 +73,7 @@ def _context_attention_kernel( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: @@ -111,7 +108,7 @@ def _context_attention_kernel( def _token_attention_kernel( self, q, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ): diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 9e1337b0f..6df1a8821 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -11,40 +11,47 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg): super().__init__(layer_num, data_type, network_config, quant_cfg) return + @override + def _parse_config(self): + self.q_lora_rank = self.network_config_["q_lora_rank"] + self.index_n_heads = self.network_config_["index_n_heads"] + self.index_head_dim = self.network_config_["index_head_dim"] + self.hidden_size = self.network_config_["hidden_size"] + @override def _init_weight(self): prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" self.wq_b_proj_ = ROWMMWeight( - weight_name=f"{prefix}.wq_b.weight", + in_dim=self.q_lora_rank, + out_dims=[self.index_n_heads * self.index_head_dim], + 
weight_names=f"{prefix}.wq_b.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="wq_b", + quant_method=None, tp_rank=0, tp_world_size=1, ) self.wk_proj_ = ROWMMWeight( - weight_name=f"{prefix}.wk.weight", + in_dim=self.hidden_size, + out_dims=[self.index_head_dim], + weight_names=f"{prefix}.wk.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="wk", + quant_method=None, tp_rank=0, tp_world_size=1, ) self.k_norm_ = LayerNormWeight( - dim=self.network_config_["index_head_dim"], + dim=self.index_head_dim, weight_name=f"{prefix}.k_norm.weight", - data_type=torch.float32, + data_type=self.data_type_, bias_name=f"{prefix}.k_norm.bias", ) self.weights_proj_ = ROWMMWeight( - weight_name=f"{prefix}.weights_proj.weight", + in_dim=self.hidden_size, + out_dims=[self.index_n_heads], + weight_names=f"{prefix}.weights_proj.weight", data_type=self.data_type_, - quant_cfg=None, - layer_num=self.layer_num_, - name="weights_proj", + quant_method=None, tp_rank=0, tp_world_size=1, ) diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index 8017a84ad..dc78f1de4 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -6,12 +6,6 @@ class IndexerKSBuffer: - """Lightweight buffer holder for NSA indexer keys+scales. - - Shares token indices with the parent MemoryManager — does NOT have its - own allocator. Only stores the per-layer kv_buffer tensor. - """ - def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int, dtype=torch.uint8): self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index d25cbd378..f907b0bed 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -1,11 +1,107 @@ +import copy +import json +import logging + from lightllm.models.registry import ModelRegistry from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager +_logger = logging.getLogger(__name__) + + +class DeepSeekV32Tokenizer: + """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based + encoding_dsv32 module instead of Jinja chat templates. + + DeepSeek-V3.2's tokenizer_config.json does not ship with a Jinja chat + template, so ``apply_chat_template`` would fail without either a manually + supplied ``--chat_template`` file or this wrapper. Activate it with + ``--tokenizer_mode deepseek_v32``. + """ + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + # Cache added vocabulary for performance (HuggingFace can be slow). + self._added_vocab = None + + # ------------------------------------------------------------------ + # Attribute delegation – everything not overridden goes to the inner + # tokenizer so that encode/decode/vocab_size/eos_token_id/… all work. 
+ # ------------------------------------------------------------------ + def __getattr__(self, name): + return getattr(self.tokenizer, name) + + def get_added_vocab(self): + if self._added_vocab is None: + self._added_vocab = self.tokenizer.get_added_vocab() + return self._added_vocab + + # ------------------------------------------------------------------ + # Core override: route apply_chat_template through encode_messages. + # ------------------------------------------------------------------ + def apply_chat_template( + self, + conversation=None, + messages=None, + tools=None, + tokenize=False, + add_generation_prompt=True, + thinking=None, + **kwargs, + ): + from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages, render_tools + + msgs = conversation if conversation is not None else messages + if msgs is None: + raise ValueError("Either 'conversation' or 'messages' must be provided") + + # Deep copy to avoid mutating the caller's messages. + msgs = copy.deepcopy(msgs) + + # Determine thinking mode. + thinking_mode = "thinking" if thinking else "chat" + + # Inject tools into the first system message (or create one) so that + # encode_messages / render_message picks them up. + if tools: + # build_prompt passes tools as bare function dicts: + # [{"name": "f", "description": "...", "parameters": {...}}] + # encoding_dsv32's render_message expects OpenAI wrapper format: + # [{"type": "function", "function": {...}}] + wrapped_tools = [] + for t in tools: + if "function" in t: + wrapped_tools.append(t) + else: + wrapped_tools.append({"type": "function", "function": t}) + + injected = False + for msg in msgs: + if msg.get("role") == "system": + existing = msg.get("tools") or [] + msg["tools"] = existing + wrapped_tools + injected = True + break + + if not injected: + # Prepend a system message that carries the tools. 
+ msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) + + prompt = encode_messages( + msgs, + thinking_mode=thinking_mode, + drop_thinking=kwargs.get("drop_thinking", True), + add_default_bos_token=kwargs.get("add_default_bos_token", True), + ) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt + @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): @@ -16,7 +112,7 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer # infer state class - infer_state_class = Deepseek3_2FlashAttentionStateInfo + infer_state_class = Deepseek3_2InferStateInfo def __init__(self, kvargs): super().__init__(kvargs) @@ -24,11 +120,11 @@ def __init__(self, kvargs): return def _init_inferstate_cls(self): - self.infer_state_class = Deepseek3_2FlashAttentionStateInfo + self.infer_state_class = Deepseek3_2InferStateInfo def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager - if "triton_fp8kv" in self.mode: + if get_env_start_args().llm_kv_type == "fp8kv": manager_class = Deepseek3_2FP8KVMemoryManager # mtp 模式下需要在mem manger上扩展draft model使用的layer diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index 8faf3cdea..a345bd1e2 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -19,16 +19,6 @@ def _fwd_kernel_destindex_copy_indexer_ks( stride_o_d, BLOCK_DMODEL: tl.constexpr, ): - """ - Triton kernel to copy FP8 K values and their scales to an indexed output buffer. - - This kernel reads FP8 key values (128 dims) and their float32 scale values, - then writes them to a compact buffer format where each entry contains: - - Bytes 0-127: FP8 key values (128 bytes) - - Bytes 128-131: Float32 scale (4 bytes) - - The destination location for each source element is specified by DestLoc. - """ cur_index = tl.program_id(0) offs_d = tl.arange(0, BLOCK_DMODEL) @@ -64,37 +54,6 @@ def _fwd_kernel_destindex_copy_indexer_ks( def destindex_copy_indexer_ks( K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor ): - """ - Copy FP8-quantized key values and their scales to indexed locations in a buffer. - - This function is used in the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) - mechanism to store compressed key representations in a memory buffer. Each key - is stored with its FP8 representation (128 bytes) followed by its float32 scale - (4 bytes), for a total of 132 bytes per key. - - Args: - K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn - FP8-quantized key values - K_scale: [q_seq_len, 1] torch.float32 - Quantization scales for each key - DestLoc: [q_seq_len] torch.int32 - Destination indices in the output buffer - O_buffer: [large_size, 1, 132] torch.uint8 - Output buffer where keys and scales will be written. - Must be a uint8 tensor to allow mixed-type storage. 
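The docstring removed above described the packed NSA indexer entry: 128 bytes of FP8 key values followed by a 4-byte float32 scale, 132 bytes per token. The following standalone PyTorch sketch packs and unpacks that layout with plain bitcast views; tensor names and sizes are illustrative, and it assumes a PyTorch build with `float8_e4m3fn` support:

```python
import torch

# Illustrative pack/unpack of the 132-byte indexer entry described above:
#   bytes 0..127   -> 128 float8_e4m3fn key values
#   bytes 128..131 -> one float32 scale (little-endian)
fp8 = torch.float8_e4m3fn
device = "cuda" if torch.cuda.is_available() else "cpu"
n, dim, pool = 4, 128, 16

k = torch.randn(n, dim, device=device)
scale = (k.abs().amax(dim=1, keepdim=True) / torch.finfo(fp8).max).clamp(min=1e-12)
k_fp8 = (k / scale).to(fp8)

buffer = torch.zeros(pool, 1, dim + 4, dtype=torch.uint8, device=device)
dest = torch.tensor([3, 7, 8, 12], device=device)

# Pack: reinterpret the fp8 payload and the float32 scale as raw bytes.
buffer[dest, 0, :dim] = k_fp8.view(torch.uint8)
buffer[dest, 0, dim:] = scale.view(torch.uint8)

# Unpack: the inverse bitcasts recover the stored values exactly.
k_back = buffer[dest, 0, :dim].view(fp8)
scale_back = buffer[dest, 0, dim:].view(torch.float32).squeeze(-1)

assert torch.equal(k_back.to(torch.float32), k_fp8.to(torch.float32))
assert torch.equal(scale_back, scale.squeeze(-1))
```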
- Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales - - Returns: - None (modifies O_buffer in-place) - - Example: - >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() - >>> k_scale = torch.randn(50, 1).cuda() - >>> dest_loc = torch.randint(0, 1024, (50,), dtype=torch.int32).cuda() - >>> o_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() - >>> destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer) - >>> # Now o_buffer[dest_loc] contains the packed k_fp8 and k_scale data - """ seq_len = DestLoc.shape[0] head_dim = K_fp8.shape[1] @@ -129,130 +88,3 @@ def destindex_copy_indexer_ks( num_stages=1, ) return - - -def test_destindex_copy_indexer_ks(): - """Test the destindex_copy_indexer_ks kernel""" - import torch.nn.functional as F - - print("=" * 80) - print("Testing destindex_copy_indexer_ks") - print("=" * 80) - - # Test parameters - q_seq_len = 50 - head_dim = 128 - large_size = 1024 - dtype = torch.bfloat16 - fp8_type = torch.float8_e4m3fn - - # Create random destination indices - dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() - actual_seq_len = len(dest_loc) - - # Create input tensors - k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") - - # Quantize to FP8 - k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - # Create output buffer (as uint8 to allow reinterpretation) - o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - - # Run kernel - destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) - - # Extract results - k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) - - # Extract scale by reinterpreting 4 bytes as float32 - scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() - k_scale_out = scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) - - # Verify results at destination locations - k_fp8_extracted = k_fp8_out[dest_loc] - k_scale_extracted = k_scale_out[dest_loc] - - # Check FP8 values match - fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) - - # Check scales match - scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) - - # Check dequantized values - k_dequant_out = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) - cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() - - print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") - print(f" FP8 values match: {fp8_match}") - print(f" Scale values match: {scale_match}") - print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - - assert fp8_match, "FP8 values do not match!" - assert scale_match, "Scale values do not match!" 
- assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - - print("✓ Basic test passed!") - print() - - # Test edge cases - print("Testing edge cases...") - - # Test with sequential indices - dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) - k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") - k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) - - k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) - scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() - k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) - - fp8_match_seq = torch.allclose(k_fp8_out_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) - scale_match_seq = torch.allclose(k_scale_out_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") - assert fp8_match_seq and scale_match_seq - print("✓ Edge case tests passed!") - print() - - # Test with single element - print("Testing single element...") - dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) - k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") - k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = ( - (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - ) - - o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) - - k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) - scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() - k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) - - fp8_match_single = torch.allclose( - k_fp8_out_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 - ) - scale_match_single = torch.allclose(k_scale_out_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") - assert fp8_match_single and scale_match_single - print("✓ Single element test passed!") - print() - - print("=" * 80) - print("All tests passed successfully! 
✓") - print("=" * 80) - - -if __name__ == "__main__": - test_destindex_copy_indexer_ks() diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index eb22fbb8f..48bc34ad6 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -6,10 +6,10 @@ @triton.jit def _fwd_kernel_extract_indexer_ks( - I_buffer, # Input buffer [large_size, 1, 132] uint8 - SrcLoc, # Source indices [req_size] int32/int64 - O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn - O_scale, # Output scale [req_size] float32 + I_buffer, # Input buffer [large_size, 1, 132] uint8 + SrcLoc, # Source indices [req_size] int32/int64 + O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn + O_scale, # Output scale [req_size] float32 stride_i_bs, stride_i_h, stride_i_d, @@ -18,98 +18,51 @@ def _fwd_kernel_extract_indexer_ks( stride_o_scale_bs, BLOCK_DMODEL: tl.constexpr, ): - """ - Triton kernel to extract FP8 K values and their scales from an indexed buffer. - - This kernel is the inverse of destindex_copy_indexer_ks. It reads from a - compact buffer format where each entry contains: - - Bytes 0-127: FP8 key values (128 bytes) - - Bytes 128-131: Float32 scale (4 bytes) - - The source location for each output element is specified by SrcLoc. - """ cur_index = tl.program_id(0) offs_d = tl.arange(0, BLOCK_DMODEL) - - # Load source index for this thread + src_index = tl.load(SrcLoc + cur_index).to(tl.int64) - - # Load K_fp8 from I_buffer[:, 0, :128] + i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d k_fp8_as_uint8 = tl.load(i_k_ptrs) - - # Convert uint8 to fp8 through bitcast + k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) - - # Store K_fp8 to output + o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d tl.store(o_k_ptrs, k_fp8) - - # Load K_scale from I_buffer[:, 0, 128:132] (4 bytes for float32) - # Load 4 bytes and reconstruct float32 (little-endian) + i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d - - # Load 4 bytes individually and combine them into uint32 + byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) - - # Combine bytes into uint32 (little-endian: byte0 is LSB) + scale_as_uint32 = byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24) - - # Bitcast uint32 to float32 + k_scale = scale_as_uint32.to(tl.float32, bitcast=True) - - # Store scale to output + o_scale_ptr = O_scale + cur_index * stride_o_scale_bs tl.store(o_scale_ptr, k_scale) - + return @torch.no_grad() def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Extract FP8-quantized key values and their scales from indexed locations in a buffer. - - This function is the inverse operation of destindex_copy_indexer_ks. It's used in - the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) mechanism to retrieve - compressed key representations from a memory buffer. - - Args: - I_buffer: [large_size, 1, 132] torch.uint8 - Input buffer containing packed FP8 keys and float32 scales. 
- Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales - SrcLoc: [req_size] torch.int32 or torch.int64 - Source indices to extract from the input buffer - - Returns: - tuple containing: - - K_fp8: [req_size, 128] torch.float8_e4m3fn - FP8-quantized key values - - K_scale: [req_size] torch.float32 - Quantization scales for each key - - Example: - >>> i_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() - >>> src_loc = torch.tensor([10, 20, 30], dtype=torch.int32).cuda() - >>> k_fp8, k_scale = extract_indexer_ks(i_buffer, src_loc) - >>> # k_fp8.shape == [3, 128], k_scale.shape == [3] - """ req_size = SrcLoc.shape[0] head_dim = 128 - + assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" - + # Allocate output tensors O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) - + grid = (req_size,) num_warps = 1 - + _fwd_kernel_extract_indexer_ks[grid]( I_buffer, SrcLoc, @@ -125,185 +78,5 @@ def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[to num_warps=num_warps, num_stages=1, ) - - return O_fp8, O_scale - - -def test_extract_indexer_ks(): - """Test the extract_indexer_ks kernel against the copy kernel""" - import torch.nn.functional as F - from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks - - print("=" * 80) - print("Testing extract_indexer_ks") - print("=" * 80) - - # Test parameters - q_seq_len = 50 - head_dim = 128 - large_size = 1024 - dtype = torch.bfloat16 - fp8_type = torch.float8_e4m3fn - - # Create random indices for writing - write_indices = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() - actual_seq_len = len(write_indices) - - # Create input tensors - k_bf16_original = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") - - # Quantize to FP8 - k_abs_max = k_bf16_original.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_original = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_original = (k_bf16_original / k_abs_max).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - # Create buffer and write data using destindex_copy_indexer_ks - buffer = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_original, k_scale_original, write_indices, buffer) - - # Now extract the data back using extract_indexer_ks - k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, write_indices) - - # Verify FP8 values match - fp8_match = torch.allclose( - k_fp8_extracted.to(torch.float32), - k_fp8_original.to(torch.float32), - atol=0, rtol=0 - ) - - # Verify scales match - scale_match = torch.allclose( - k_scale_extracted, - k_scale_original.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - # Check dequantized values - k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) - cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16_original, dim=-1).mean() - - print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") - print(f" FP8 values match: {fp8_match}") - print(f" Scale values match: {scale_match}") - print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - - assert fp8_match, "FP8 values do not match!" 
- assert scale_match, "Scale values do not match!" - assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - - print("✓ Basic test passed!") - print() - - # Test with sequential indices - print("Testing sequential indices...") - write_indices_seq = torch.arange(20, device="cuda", dtype=torch.int32) - k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") - k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, write_indices_seq, buffer_seq) - k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, write_indices_seq) - - fp8_match_seq = torch.allclose( - k_fp8_ext_seq.to(torch.float32), - k_fp8_seq.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_seq = torch.allclose( - k_scale_ext_seq, - k_scale_seq.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") - assert fp8_match_seq and scale_match_seq - print("✓ Sequential test passed!") - print() - - # Test with single element - print("Testing single element...") - write_idx_single = torch.tensor([42], device="cuda", dtype=torch.int32) - k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") - k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_single, k_scale_single, write_idx_single, buffer_single) - k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, write_idx_single) - - fp8_match_single = torch.allclose( - k_fp8_ext_single.to(torch.float32), - k_fp8_single.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_single = torch.allclose( - k_scale_ext_single, - k_scale_single.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") - assert fp8_match_single and scale_match_single - print("✓ Single element test passed!") - print() - - # Test with larger batch to check performance characteristics - print("Testing larger batch (performance check)...") - write_indices_large = torch.randint(0, large_size * 10, (500,), device="cuda", dtype=torch.int32).unique() - actual_large_len = len(write_indices_large) - k_bf16_large = torch.randn((actual_large_len, head_dim), dtype=dtype, device="cuda") - k_abs_max_large = k_bf16_large.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_large = (k_abs_max_large / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_large = (k_bf16_large / k_abs_max_large).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - buffer_large = torch.zeros((large_size * 10, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_large, k_scale_large, write_indices_large, buffer_large) - - # Warm up - for _ in range(3): - _ = extract_indexer_ks(buffer_large, write_indices_large) - - # Time it - torch.cuda.synchronize() - import time - start = time.time() - 
for _ in range(100): - k_fp8_ext_large, k_scale_ext_large = extract_indexer_ks(buffer_large, write_indices_large) - torch.cuda.synchronize() - elapsed = time.time() - start - - fp8_match_large = torch.allclose( - k_fp8_ext_large.to(torch.float32), - k_fp8_large.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_large = torch.allclose( - k_scale_ext_large, - k_scale_large.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - print(f" Large batch (size={actual_large_len}): FP8={fp8_match_large}, Scale={scale_match_large}") - print(f" Average time per call: {elapsed/100*1000:.3f} ms") - assert fp8_match_large and scale_match_large - print("✓ Large batch test passed!") - print() - - print("=" * 80) - print("All tests passed successfully! ✓") - print("=" * 80) - -if __name__ == "__main__": - test_extract_indexer_ks() + return O_fp8, O_scale diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py index e8f1bbfa2..1c1f72b7d 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -1,4 +1,3 @@ -# import triton import triton.language as tl import torch @@ -35,68 +34,44 @@ def _fp8_paged_mqa_logits_kernel( pid_m = tl.program_id(0) pid_n = tl.program_id(1) - # Compute the range of seq positions this block handles start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N - # Offset arrays for this block offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) - # Initialize accumulator for logits logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Create masks mask_m = offs_m < seq_len mask_n = offs_n < seq_len_kv - # Load mem_indices for the KV positions mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) - # Load scales for K scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) - # Loop over all heads for h in range(num_heads): - # Load weights for this head weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, mask=mask_m, other=0.0) - - # Initialize score accumulator for this head score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Loop over head_dim in blocks for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): d_start = d_block * BLOCK_SIZE_D offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) mask_d = offs_d < head_dim - # Load Q for this head and dimension block - # Q shape: (seq_len, num_heads, head_dim) q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) - # Load K for this dimension block - # KV shape: (pool_size, head_dim) as FP8 data k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim mask_k = mask_n[:, None] & mask_d[None, :] k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) - # Apply scale to K (scale is per-row of K) k = k * scales[:, None] - # Compute partial dot product: q @ k.T - # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) - # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) score += tl.dot(q, tl.trans(k)) - - # Apply ReLU to score score = tl.maximum(score, 0.0) - - # Multiply by weights and accumulate to logits logits += score * weights[:, None] - # Apply mask based on cu_seqlen_ks and cu_seqlen_ke mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) mask_ke = 
tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) @@ -104,7 +79,6 @@ def _fp8_paged_mqa_logits_kernel( mask_hi = offs_n[None, :] < mask_ke[:, None] mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] - # Apply mask (-inf for masked positions) logits = tl.where(mask_valid, logits, float("-inf")) # Store output diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1661a3b87..877f029a7 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -92,9 +92,12 @@ def make_argument_parser() -> argparse.ArgumentParser: "--tokenizer_mode", type=str, default="fast", - help="""tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, - slow mode is good for debug and test, fast mode get best performance, auto mode will - try to use fast mode, if failed will use slow mode""", + help="""tokenizer load mode, can be slow, fast, auto, or deepseek_v32. + slow mode load fast but run slow, good for debug and test. + fast mode get best performance. + auto mode will try to use fast mode, if failed will use slow mode. + deepseek_v32 mode wraps the tokenizer with Python-based DSML chat + template encoding for DeepSeek-V3.2 models (no --chat_template needed).""", ) parser.add_argument( "--load_way", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 928d840c8..d91bb1d94 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -463,39 +463,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") # Additional usage chunk - # Finalize any pending tool calls (e.g., DSML format last invoke) - if request.tool_choice != "none" and request.tools and parser_dict: - for _idx, _parser in parser_dict.items(): - _, finalize_calls = _parser.finalize_stream() - history_tool_calls_cnt = _get_history_tool_calls_cnt(request) - for call_item in finalize_calls: - if call_item.name: - tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt) - function_name = call_item.name - else: - tool_call_id = None - function_name = None - tool_call = ToolCall( - id=tool_call_id, - index=getattr(call_item, "tool_index", None), - function=FunctionResponse( - name=function_name, - arguments=call_item.parameters, - ), - ) - choice_data = ChatCompletionStreamResponseChoice( - index=0, - delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), - finish_reason="tool_calls", - ) - chunk = ChatCompletionStreamResponse( - id=group_request_id, - created=created_time, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( prompt_tokens=prompt_tokens, diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index d77184863..f770459a5 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,28 +1,11 @@ -import json -import os - tokenizer = None -_model_type = None def init_tokenizer(args): - global tokenizer, _model_type + global tokenizer from lightllm.server.tokenizer import get_tokenizer tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) - - # Detect model type for specialized encoding (e.g. 
DeepSeek-V3.2) - config_path = os.path.join(args.model_dir, "config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - model_config = json.load(f) - _model_type = model_config.get("model_type", None) - # Check architectures as fallback - if _model_type is None: - archs = model_config.get("architectures", []) - if any("DeepseekV32" in a for a in archs): - _model_type = "deepseek_v32" - chat_path = args.chat_template if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: @@ -31,14 +14,9 @@ def init_tokenizer(args): async def build_prompt(request, tools) -> str: - global tokenizer, _model_type + global tokenizer # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] - - # Use DeepSeek-V3.2 native encoding when applicable - if _model_type == "deepseek_v32": - return _build_prompt_dsv32(messages, tools, request) - kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -62,27 +40,3 @@ async def build_prompt(request, tools) -> str: tools=tools, ) return input_str - - -def _build_prompt_dsv32(messages, tools, request): - from lightllm.server.encoding_dsv32 import encode_messages - - # Inject tools into system message if present - if tools is not None and len(tools) > 0: - wrapped_tools = [t if "function" in t else {"function": t} for t in tools] - if messages and messages[0].get("role") == "system": - messages[0]["tools"] = wrapped_tools - else: - messages.insert(0, {"role": "system", "tools": wrapped_tools}) - - # Determine thinking mode from request - thinking = False - if request.chat_template_kwargs: - thinking = request.chat_template_kwargs.get("thinking", False) or request.chat_template_kwargs.get( - "enable_thinking", False - ) - - thinking_mode = "thinking" if thinking else "chat" - drop_thinking = messages[-1]["role"] == "user" if messages else True - - return encode_messages(messages, thinking_mode=thinking_mode, drop_thinking=drop_thinking) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index d7fd35961..d955aa6a8 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -331,38 +331,31 @@ class SamplingParams(ctypes.Structure): _top_p: float = 1.0 _top_k: int = -1 # -1 is for all - @staticmethod - def _get(kwargs, key, default): - """Get value from kwargs, falling back to default when value is None or missing.""" - val = kwargs.get(key) - return val if val is not None else default - def init(self, tokenizer, **kwargs): super().__init__() - _get = SamplingParams._get - self.best_of = _get(kwargs, "best_of", 1) - self.n = _get(kwargs, "n", self.best_of) - self.do_sample = _get(kwargs, "do_sample", SamplingParams._do_sample) - self.presence_penalty = _get(kwargs, "presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = _get(kwargs, "frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = _get(kwargs, "repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = _get(kwargs, "temperature", SamplingParams._temperature) - self.top_p = _get(kwargs, "top_p", SamplingParams._top_p) - self.top_k = _get(kwargs, "top_k", SamplingParams._top_k) - self.ignore_eos = _get(kwargs, "ignore_eos", False) - self.image_max_patch_num = _get(kwargs, "image_max_patch_num", -1) - self.max_new_tokens 
= _get(kwargs, "max_new_tokens", 16) - self.min_new_tokens = _get(kwargs, "min_new_tokens", 1) - self.input_penalty = _get(kwargs, "input_penalty", DEFAULT_INPUT_PENALTY) - self.group_request_id = _get(kwargs, "group_request_id", -1) - self.suggested_dp_index = _get(kwargs, "suggested_dp_index", -1) - - self.skip_special_tokens = _get(kwargs, "skip_special_tokens", SKIP_SPECIAL_TOKENS) - self.disable_prompt_cache = _get(kwargs, "disable_prompt_cache", False) - - self.add_special_tokens = _get(kwargs, "add_special_tokens", True) - self.add_spaces_between_special_tokens = _get(kwargs, "add_spaces_between_special_tokens", True) - self.print_eos_token = _get(kwargs, "print_eos_token", False) + self.best_of = kwargs.get("best_of", 1) + self.n = kwargs.get("n", self.best_of) + self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) + self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) + self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) + self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) + self.temperature = kwargs.get("temperature", SamplingParams._temperature) + self.top_p = kwargs.get("top_p", SamplingParams._top_p) + self.top_k = kwargs.get("top_k", SamplingParams._top_k) + self.ignore_eos = kwargs.get("ignore_eos", False) + self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) + self.max_new_tokens = kwargs.get("max_new_tokens", 16) + self.min_new_tokens = kwargs.get("min_new_tokens", 1) + self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) + self.group_request_id = kwargs.get("group_request_id", -1) + self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) + + self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) + self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) + + self.add_special_tokens = kwargs.get("add_special_tokens", True) + self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) + self.print_eos_token = kwargs.get("print_eos_token", False) self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() self.exponential_decay_length_penalty.initialize(kwargs.get("exponential_decay_length_penalty", (1, 1.0))) diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index c3faf21e7..3a8fddf74 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -1453,27 +1453,26 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami class DeepSeekV32Detector(BaseFormatDetector): """ - Detector for DeepSeek V3.2 model function call format (DSML). - - DeepSeek V3.2 uses a new DSML (DeepSeek Markup Language) format for tool calls, - which is XML-like rather than JSON-based. + Detector for DeepSeek V3.2 model function call format using DSML + (DeepSeek Markup Language). 
Format Structure: ``` <|DSML|function_calls> <|DSML|invoke name="get_weather"> - <|DSML|parameter name="location" string="true">杭州 - <|DSML|parameter name="date" string="true">2024-01-16 - <|DSML|invoke name="get_weather"> - <|DSML|parameter name="location" string="true">北京 - <|DSML|parameter name="date" string="true">2024-01-16 + <|DSML|parameter name="location" string="true">Hangzhou + <|DSML|parameter name="date" string="true">2024-01-16 + + ``` Key Components: - - Tool Calls Section: Starts with `<|DSML|function_calls>` - - Individual Invoke: `<|DSML|invoke name="function_name">` - - Parameters: `<|DSML|parameter name="param_name" string="true">value` - - Parameter types are inferred from the tool schema for proper JSON serialization + - Function Calls Block: `<|DSML|function_calls>` ... `` + - Individual Invocation: `<|DSML|invoke name="func">` ... `` + - Parameters: `<|DSML|parameter name="key" string="true|false">value` + - string="true": value is plain text (will be JSON-escaped) + - string="false": value is JSON (numbers, booleans, arrays, objects) + - Supports multiple parallel tool calls Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 """ @@ -1481,333 +1480,132 @@ class DeepSeekV32Detector(BaseFormatDetector): def __init__(self): super().__init__() self.dsml_token = "|DSML|" - self.bot_token = "<|DSML|function_calls>" - self.eot_token = "" # DSML format has no explicit end token - self.invoke_prefix = '<|DSML|invoke name="' - self.parameter_prefix = '<|DSML|parameter name="' - - # Regex for complete parsing + self.bot_token = f"<{self.dsml_token}function_calls>" + self.eot_token = f"" + self.invoke_start_prefix = f"<{self.dsml_token}invoke" + self.invoke_end_token = f"" + self.param_end_token = f"" + + # Regex for complete invoke extraction + _de = re.escape(self.dsml_token) self.invoke_regex = re.compile( - r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)(?=<|DSML|invoke|$)', + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*?)', re.DOTALL, ) - # Captures: (param_name, is_string, value) - self.parameter_regex = re.compile( - r'<|DSML|parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)(?=<|DSML|parameter|<|DSML|invoke|$)', + # Regex for parameter extraction + self.param_regex = re.compile( + rf'<{_de}parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)', + re.DOTALL, + ) + # Regex for partial invoke (name known, body still streaming) + self.partial_invoke_regex = re.compile( + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*)', re.DOTALL, ) - # Streaming state self._last_arguments = "" - self._current_invoke_text = "" - self._invoke_count = 0 - self._param_count_in_invoke = 0 - self._accumulated_params: Dict[str, str] = {} - self._json_started = False - self._tools_schema: Optional[Dict[str, Dict]] = None - self._tool_indices: Optional[Dict[str, int]] = None - self._current_func_name: Optional[str] = None - self._in_tool_call_sequence = False # Set True once bot_token seen + self._accumulated_params: List[tuple] = [] + self._in_function_calls = False # Track if we're inside a function_calls block def has_tool_call(self, text: str) -> bool: - """Check if the text contains a DeepSeek V3.2 DSML format tool call.""" return self.bot_token in text - def _get_param_type(self, func_name: str, param_name: str, tools: List[Tool]) -> str: - """Get the JSON Schema type of a parameter from the tool definition.""" - if self._tools_schema is None: - self._tools_schema = {} - for tool in tools: - if tool.function.name and tool.function.parameters: - props = 
tool.function.parameters.get("properties", {}) - self._tools_schema[tool.function.name] = props - - func_schema = self._tools_schema.get(func_name, {}) - param_schema = func_schema.get(param_name, {}) - return param_schema.get("type", "string") - - def _convert_param_value(self, value: str, is_string_attr: str, param_type: str) -> Any: - """Convert a raw parameter value string to the appropriate Python type. - - Args: - value: The raw string value from the DSML parameter tag. - is_string_attr: The "string" attribute from DSML ("true" or "false"). - If "true", the value is treated as a raw string. - If "false", the value is parsed based on param_type or JSON. - param_type: The JSON Schema type from the tool definition (fallback). - """ - value = value.strip() - if value.lower() == "null": - return None - - # Use DSML string attribute as primary signal - if is_string_attr == "true": - return value - - # string="false" - parse based on schema type or attempt JSON - param_type = param_type.lower() - if param_type in ("integer", "int"): - try: - return int(value) - except (ValueError, TypeError): - return value - elif param_type in ("number", "float"): - try: - val = float(value) - # Only coerce to int if it's actually an integer string - if "." not in value and "e" not in value.lower(): - return int(value) - return val - except (ValueError, TypeError, OverflowError): - return value - elif param_type in ("boolean", "bool"): - lower = value.lower() - if lower in ("true", "1"): - return True - elif lower in ("false", "0"): - return False + def _dsml_params_to_json(self, params: List[tuple]) -> str: + """Convert DSML parameter tuples (name, is_str, value) to a JSON arguments string.""" + args = {} + for name, is_str, value in params: + if is_str == "true": + args[name] = value else: - logger.warning(f"Unexpected boolean value: {value!r}, treating as string") - return value - elif param_type in ("object", "array"): - try: - return json.loads(value) - except json.JSONDecodeError: - return value - else: - # Unknown type with string="false" - try JSON parse, fallback to string - try: - return json.loads(value) - except json.JSONDecodeError: - return value - - def _parse_invoke_params(self, invoke_content: str, func_name: str, tools: List[Tool]) -> Dict: - """Parse all parameters from an invoke block content.""" - params = {} - for param_name, is_string_attr, param_value in self.parameter_regex.findall(invoke_content): - param_name = param_name.strip() - param_value = param_value.strip() - param_type = self._get_param_type(func_name, param_name, tools) - params[param_name] = self._convert_param_value(param_value, is_string_attr, param_type) - return params + try: + args[name] = json.loads(value) + except (json.JSONDecodeError, ValueError): + args[name] = value + return json.dumps(args, ensure_ascii=False) def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - """ - One-time parsing: Detects and parses DSML tool calls in the provided text. 
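`_dsml_params_to_json` above maps the `(name, string_attr, value)` tuples captured by the parameter regex to a JSON arguments string: `string="true"` keeps the raw text, `string="false"` is parsed as JSON with a fallback to the raw string. A small standalone mirror of that rule with hand-written tuples (illustrative input, not real detector output):

```python
import json

# Standalone mirror of the conversion rule: string="true" keeps raw text,
# string="false" is parsed as JSON, falling back to the raw string.
def dsml_params_to_json(params):
    args = {}
    for name, is_str, value in params:
        if is_str == "true":
            args[name] = value
        else:
            try:
                args[name] = json.loads(value)
            except (json.JSONDecodeError, ValueError):
                args[name] = value
    return json.dumps(args, ensure_ascii=False)


params = [
    ("location", "true", "Hangzhou"),
    ("days", "false", "3"),
    ("units", "false", '["metric", "imperial"]'),
]
print(dsml_params_to_json(params))
# {"location": "Hangzhou", "days": 3, "units": ["metric", "imperial"]}
```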
- """ - if self.bot_token not in text: - return StreamingParseResult(normal_text=text, calls=[]) - + """One-time parsing for DSML format tool calls.""" idx = text.find(self.bot_token) - normal_text = text[:idx].strip() if idx > 0 else "" - tool_section = text[idx:] + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) tool_indices = self._get_tool_indices(tools) calls = [] - try: - for func_name, invoke_content in self.invoke_regex.findall(tool_section): - func_name = func_name.strip() - if func_name not in tool_indices: - logger.warning(f"Model attempted to call undefined function: {func_name}") - continue + invoke_matches = self.invoke_regex.findall(text) + for func_name, invoke_body in invoke_matches: + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + continue - params = self._parse_invoke_params(invoke_content, func_name, tools) - calls.append( - ToolCallItem( - tool_index=tool_indices[func_name], - name=func_name, - parameters=json.dumps(params, ensure_ascii=False), - ) - ) - return StreamingParseResult(normal_text=normal_text, calls=calls) - except Exception as e: - logger.error(f"Error in DeepSeekV32 detect_and_parse: {e}") - return StreamingParseResult(normal_text=text) + param_matches = self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) - def finalize_streaming(self, tools: List[Tool]) -> StreamingParseResult: - """Finalize the last pending tool call when generation ends (EOS). + calls.append( + ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=args_json, + ) + ) - The DSML format has no explicit end token, so the last invoke's last - parameter may remain unconfirmed. This method should be called when - the stream ends to close any open JSON and emit remaining parameters. 
- """ - if not self.current_tool_name_sent or self.current_tool_id < 0: - return StreamingParseResult() + return StreamingParseResult(normal_text=normal_text, calls=calls) - calls: List[ToolCallItem] = [] + def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: + """Streaming incremental parsing for DSML format tool calls.""" + self._buffer += new_text current_text = self._buffer - try: - # Find current invoke text - invoke_positions = [] - search_start = 0 - while True: - pos = current_text.find(self.invoke_prefix, search_start) - if pos == -1: - break - invoke_positions.append(pos) - search_start = pos + len(self.invoke_prefix) - - if self._invoke_count < len(invoke_positions): - invoke_start = invoke_positions[self._invoke_count] - invoke_text = current_text[invoke_start:] - - name_content_start = len(self.invoke_prefix) - name_end = invoke_text.find('">', name_content_start) - if name_end != -1: - func_name = invoke_text[name_content_start:name_end].strip() - invoke_body = invoke_text[name_end + 2 :] - - # Parse all remaining params (including the last unconfirmed one) - param_matches = list(self.parameter_regex.finditer(invoke_body)) - for i in range(self._param_count_in_invoke, len(param_matches)): - match = param_matches[i] - param_name = match.group(1).strip() - is_string_attr = match.group(2) - param_value = match.group(3).strip() - - param_type = self._get_param_type(func_name, param_name, tools) - converted_value = self._convert_param_value(param_value, is_string_attr, param_type) - serialized_value = json.dumps(converted_value, ensure_ascii=False) - - if not self._json_started: - json_fragment = "{" + f'"{param_name}": {serialized_value}' - self._json_started = True - else: - json_fragment = f', "{param_name}": {serialized_value}' + # Check if we're inside a function_calls block or starting one + has_tool = self.has_tool_call(current_text) or self._in_function_calls - self._accumulated_params[param_name] = converted_value - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=json_fragment, - ) - ) - self.streamed_args_for_tool[self.current_tool_id] += json_fragment + if not has_tool: + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult() - # Close the JSON object - if self._json_started: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters="}", - ) - ) - self.streamed_args_for_tool[self.current_tool_id] += "}" - elif self.current_tool_name_sent: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters="{}", - ) - ) - self.streamed_args_for_tool[self.current_tool_id] = "{}" - - # Update prev_tool_call_arr - if self.current_tool_id < len(self.prev_tool_call_arr): - self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params - - # Reset state - self._invoke_count += 1 - self.current_tool_id += 1 - self.current_tool_name_sent = False - self._json_started = False - self._accumulated_params = {} self._buffer = "" + for e_token in [self.eot_token, self.invoke_end_token]: + if e_token in new_text: + new_text = new_text.replace(e_token, "") + return StreamingParseResult(normal_text=new_text) - return StreamingParseResult(normal_text="", calls=calls) - except Exception as e: - logger.error(f"Error in DeepSeekV32 finalize_streaming: {e}") - return StreamingParseResult(normal_text="", calls=calls) - - def parse_streaming_increment(self, 
new_text: str, tools: List[Tool]) -> StreamingParseResult: - """ - Streaming incremental parsing for DeepSeek V3.2 DSML tool calls. + # Mark that we're inside a function_calls block + if self.has_tool_call(current_text): + self._in_function_calls = True - The DSML format streams line-by-line with invoke/parameter tokens. - We accumulate parameters and only emit JSON fragments when a parameter's - value is confirmed complete (by seeing the next parameter/invoke boundary). - """ - self._buffer += new_text - current_text = self._buffer + # Check if function_calls block has ended + if self.eot_token in current_text: + self._in_function_calls = False - # Check if we have any DSML content - if not self._in_tool_call_sequence: - if not self.has_tool_call(current_text): - # Check for partial start token - if self._ends_with_partial_token(current_text, self.bot_token): - return StreamingParseResult() - self._buffer = "" - return StreamingParseResult(normal_text=new_text) - self._in_tool_call_sequence = True - - if self._tool_indices is None: + if not hasattr(self, "_tool_indices"): self._tool_indices = self._get_tool_indices(tools) calls: List[ToolCallItem] = [] try: - # Find all invoke starts in current buffer - invoke_positions = [] - search_start = 0 - while True: - pos = current_text.find(self.invoke_prefix, search_start) - if pos == -1: - break - invoke_positions.append(pos) - search_start = pos + len(self.invoke_prefix) - - if not invoke_positions: - # Have bot_token but no invoke yet - keep buffering - return StreamingParseResult() - - # Process only the current (latest) invoke block - current_invoke_idx = self._invoke_count - if current_invoke_idx >= len(invoke_positions): - # All invokes already processed, keep buffering for new ones - return StreamingParseResult() - - invoke_start = invoke_positions[current_invoke_idx] - # Whether the current invoke is bounded by a next invoke - invoke_is_bounded = current_invoke_idx + 1 < len(invoke_positions) - if invoke_is_bounded: - invoke_end = invoke_positions[current_invoke_idx + 1] - else: - invoke_end = len(current_text) - - invoke_text = current_text[invoke_start:invoke_end] + # Try to find complete invoke blocks first + complete_invoke_match = self.invoke_regex.search(current_text) + if complete_invoke_match: + func_name = complete_invoke_match.group(1) + invoke_body = complete_invoke_match.group(2) - # Extract function name - name_start = invoke_text.find(self.invoke_prefix) - if name_start == -1: - return StreamingParseResult() - - name_content_start = name_start + len(self.invoke_prefix) - name_end = invoke_text.find('">', name_content_start) - if name_end == -1: - # Function name not complete yet - return StreamingParseResult() - - func_name = invoke_text[name_content_start:name_end].strip() + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] - # Initialize state for this tool call - if self.current_tool_id == -1: - self.current_tool_id = 0 - self.prev_tool_call_arr = [] - self.streamed_args_for_tool = [""] + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") - while len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - while len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") + param_matches = 
self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) - # Send tool name if not sent yet - if not self.current_tool_name_sent: - if func_name and func_name in self._tool_indices: + if not self.current_tool_name_sent: calls.append( ToolCallItem( tool_index=self.current_tool_id, @@ -1816,109 +1614,101 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami ) ) self.current_tool_name_sent = True - self.prev_tool_call_arr[self.current_tool_id] = { - "name": func_name, - "arguments": {}, - } - self._current_func_name = func_name - self._accumulated_params = {} - self._param_count_in_invoke = 0 - self._json_started = False - return StreamingParseResult(calls=calls) - return StreamingParseResult() - - # Parse parameters from the invoke block content - invoke_body = invoke_text[name_end + 2 :] # after '">' - - # Find all parameter starts within this invoke body - param_positions = [] - ps = 0 - while True: - pp = invoke_body.find(self.parameter_prefix, ps) - if pp == -1: - break - param_positions.append(pp) - ps = pp + len(self.parameter_prefix) - - # A parameter is "confirmed" when the next parameter/invoke boundary is visible, - # meaning the parameter's value won't grow further. - # For the last parameter in the invoke body, it's only confirmed if - # the invoke itself is bounded by a next invoke. - confirmed_count = 0 - for pi in range(len(param_positions)): - if pi + 1 < len(param_positions): - confirmed_count += 1 - elif invoke_is_bounded: - confirmed_count += 1 - - # Only emit newly confirmed parameters - if confirmed_count > self._param_count_in_invoke: - param_matches = list(self.parameter_regex.finditer(invoke_body)) - for i in range(self._param_count_in_invoke, min(confirmed_count, len(param_matches))): - match = param_matches[i] - param_name = match.group(1).strip() - is_string_attr = match.group(2) - param_value = match.group(3).strip() - - param_type = self._get_param_type(func_name, param_name, tools) - converted_value = self._convert_param_value(param_value, is_string_attr, param_type) - serialized_value = json.dumps(converted_value, ensure_ascii=False) - - if not self._json_started: - json_fragment = "{" + f'"{param_name}": {serialized_value}' - self._json_started = True - else: - json_fragment = f', "{param_name}": {serialized_value}' - - self._accumulated_params[param_name] = converted_value + # Send complete arguments (or remaining diff) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = args_json[sent:] + if argument_diff: calls.append( ToolCallItem( tool_index=self.current_tool_id, name=None, - parameters=json_fragment, + parameters=argument_diff, ) ) - self.streamed_args_for_tool[self.current_tool_id] += json_fragment + self.streamed_args_for_tool[self.current_tool_id] += argument_diff - self._param_count_in_invoke = confirmed_count + try: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": json.loads(args_json), + } + except json.JSONDecodeError: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } - # Check if next invoke has started (meaning current one is complete) - if invoke_is_bounded: - # Current invoke is complete, close JSON and advance - if self._json_started: - close_fragment = "}" - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=close_fragment, - ) - ) - self.streamed_args_for_tool[self.current_tool_id] += close_fragment + # Remove processed 
invoke from buffer + invoke_end_pos = current_text.find(self.invoke_end_token, complete_invoke_match.start()) + if invoke_end_pos != -1: + self._buffer = current_text[invoke_end_pos + len(self.invoke_end_token) :] else: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters="{}", - ) - ) - self.streamed_args_for_tool[self.current_tool_id] = "{}" + self._buffer = current_text[complete_invoke_match.end() :] - # Update prev_tool_call_arr - self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params - - # Advance to next invoke, prune consumed buffer content - # Reset _invoke_count to 0 since buffer positions are now relative - self._buffer = current_text[invoke_end:] - self._invoke_count = 0 self.current_tool_id += 1 - self.current_tool_name_sent = False self._last_arguments = "" - self._accumulated_params = {} - self._param_count_in_invoke = 0 - self._json_started = False + self.current_tool_name_sent = False + self._accumulated_params = [] + self.streamed_args_for_tool.append("") + + return StreamingParseResult(normal_text="", calls=calls) + + # Partial invoke: name is known but parameters are still streaming + partial_match = self.partial_invoke_regex.search(current_text) + if partial_match: + func_name = partial_match.group(1) + partial_body = partial_match.group(2) + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if not self.current_tool_name_sent: + if func_name in self._tool_indices: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Stream arguments as complete parameters are parsed + param_matches = self.param_regex.findall(partial_body) + if param_matches and len(param_matches) > len(self._accumulated_params): + self._accumulated_params = param_matches + current_args_json = self._dsml_params_to_json(param_matches) + + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = current_args_json[sent:] + + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + try: + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = json.loads(current_args_json) + except json.JSONDecodeError: + pass return StreamingParseResult(normal_text="", calls=calls) @@ -2020,19 +1810,3 @@ def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]: final_normal_text = sp_result.normal_text return final_normal_text, final_calls - - def finalize_stream(self) -> Tuple[str, list[ToolCallItem]]: - """Finalize streaming when generation ends. - - For detectors that lack an explicit end-of-tool-call token (like DSML), - this closes any pending tool call JSON. For other detectors, this is a no-op. - - Returns: - A tuple of (normal_text, calls) like parse_stream_chunk. 
- """ - if not self.tools: - return "", [] - if hasattr(self.detector, "finalize_streaming"): - sp_result = self.detector.finalize_streaming(self.tools) - return sp_result.normal_text, sp_result.calls - return "", [] diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index c51774898..88b099459 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -241,7 +241,8 @@ def match_prefix(self, key, update_refs=False): value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) return tree_node, len(value), value else: - self.dec_node_ref_counter(self.root_node) + if update_refs: + self.dec_node_ref_counter(self.root_node) return None, 0, None def _match_prefix_helper( diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index e2ccf290e..ca3901ebd 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -3,6 +3,7 @@ from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args @@ -16,7 +17,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_mask_eos_reqs, is_all_greedy, ) = _get_post_sample_tensors(reqs) - eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True) + eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) sampling_params_manager = g_infer_context.req_manager.req_sampling_params_manager @@ -128,12 +129,14 @@ def _get_post_sample_tensors(reqs: List[InferReq]): is_all_greedy = False req_idxes.append(req_obj.req_idx) - req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True) - temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True) - top_ps_cpu = torch.tensor(top_ps, dtype=torch.float, device="cpu", pin_memory=True) - top_ks_cpu = torch.tensor(top_ks, dtype=torch.int32, device="cpu", pin_memory=True) - length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True) - mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True) + req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) + temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) + top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) + top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) + length_penalty_param_cpu = g_pin_mem_manager.gen_from_list( + key="length_penalty_param", data=length_penalty_param, dtype=torch.int32 + ) + mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) return ( req_idxes_cpu.cuda(non_blocking=True), diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 
e0b2bd425..29385c23f 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -44,6 +44,17 @@ def get_tokenizer( **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" + # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with + # a Python-based apply_chat_template that uses encoding_dsv32.py. + if tokenizer_mode == "deepseek_v32": + from ..models.deepseek3_2.model import DeepSeekV32Tokenizer + + hf_tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs + ) + logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") + return DeepSeekV32Tokenizer(hf_tokenizer) + if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") From 5baa683090b5ecb78f12a11cc32251cd5d1d0965 Mon Sep 17 00:00:00 2001 From: Developer Date: Wed, 4 Feb 2026 13:50:50 +0000 Subject: [PATCH 20/23] fix --- .../layer_infer/transformer_layer_infer.py | 5 --- .../layer_weights/nsa_indexer_layer_weight.py | 2 - lightllm/models/deepseek3_2/mem_manager.py | 4 -- lightllm/models/deepseek3_2/model.py | 45 +++++++++++++------ 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 9dba923cc..13a0c1394 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,5 +1,3 @@ -from typing import override - import torch from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer @@ -33,7 +31,6 @@ def _get_nsa_backend(self): self._nsa_backend = self._nsa_backend_class(model=None) return self._nsa_backend - @override def _get_qkv( self, input: torch.Tensor, @@ -68,7 +65,6 @@ def _get_qkv( ) return q, cache_kv - @override def _context_attention_kernel( self, q: torch.Tensor, @@ -104,7 +100,6 @@ def _context_attention_kernel( ) return mla_out - @override def _token_attention_kernel( self, q, diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 6df1a8821..023b89979 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -11,14 +11,12 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg): super().__init__(layer_num, data_type, network_config, quant_cfg) return - @override def _parse_config(self): self.q_lora_rank = self.network_config_["q_lora_rank"] self.index_n_heads = self.network_config_["index_n_heads"] self.index_head_dim = self.network_config_["index_head_dim"] self.hidden_size = self.network_config_["hidden_size"] - @override def _init_weight(self): prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index dc78f1de4..fdb2e87c6 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,4 +1,3 @@ -from typing_extensions import override import torch from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager @@ -15,16 +14,13 @@ def __init__(self, size, dtype, head_num, 
head_dim, layer_num, always_copy=False super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, layer_num) - @override def get_cell_size(self): return super().get_cell_size() + 132 - @override def _free_buffers(self): super()._free_buffers() self.indexer_ks_buffer = None - @override def resize_mem(self, new_size): super().resize_mem(new_size) self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, self.layer_num) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index f907b0bed..77804096b 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -1,17 +1,26 @@ import copy import json import logging +import os from lightllm.models.registry import ModelRegistry from lightllm.models.deepseek2.model import Deepseek2TpPartModel -from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager _logger = logging.getLogger(__name__) +# When ENABLE_NSA is set, use the full V32 NSA (Native Sparse Attention) pipeline +# including the indexer, custom memory manager, and NSA-aware attention kernels. +# When not set, fall back to the DeepSeek V3 (Deepseek2) inference path while +# keeping V32-specific tokenizer/parser support intact. +_ENABLE_NSA = os.environ.get("ENABLE_NSA", "0").lower() in ("1", "true") + +if _ENABLE_NSA: + from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight + from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer + from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo + from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager + class DeepSeekV32Tokenizer: """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based @@ -105,24 +114,32 @@ def apply_chat_template( @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): - # weight class - transformer_weight_class = Deepseek3_2TransformerLayerWeight - - # infer class - transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer - - # infer state class - infer_state_class = Deepseek3_2InferStateInfo + # When ENABLE_NSA is set, override with V32-specific NSA classes. + # Otherwise, inherit the V3/V2 classes from Deepseek2TpPartModel. 
+ if _ENABLE_NSA: + transformer_weight_class = Deepseek3_2TransformerLayerWeight + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + infer_state_class = Deepseek3_2InferStateInfo def __init__(self, kvargs): super().__init__(kvargs) - self.index_topk = self.config["index_topk"] + if _ENABLE_NSA: + self.index_topk = self.config["index_topk"] + else: + _logger.info("ENABLE_NSA is not set, using DeepSeek V3 inference path (no NSA indexer).") return def _init_inferstate_cls(self): - self.infer_state_class = Deepseek3_2InferStateInfo + if _ENABLE_NSA: + self.infer_state_class = Deepseek3_2InferStateInfo + else: + super()._init_inferstate_cls() def _init_mem_manager(self): + if not _ENABLE_NSA: + # Fall back to the standard V3/V2 memory manager (no indexer buffer). + return super()._init_mem_manager() + manager_class = Deepseek3_2MemoryManager if get_env_start_args().llm_kv_type == "fp8kv": manager_class = Deepseek3_2FP8KVMemoryManager From a2298e92fa6b4b712968afbd4d59190fea9d629c Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 4 Feb 2026 14:33:57 +0000 Subject: [PATCH 21/23] deepseekv32 model_type condition --- lightllm/server/tokenizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 29385c23f..b5b514858 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -43,10 +43,12 @@ def get_tokenizer( *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) + model_type = model_cfg.get("model_type", "") """Gets a tokenizer for the given model name via Huggingface.""" # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with # a Python-based apply_chat_template that uses encoding_dsv32.py. - if tokenizer_mode == "deepseek_v32": + if model_type == "deepseek_v32": from ..models.deepseek3_2.model import DeepSeekV32Tokenizer hf_tokenizer = AutoTokenizer.from_pretrained( @@ -86,8 +88,6 @@ def get_tokenizer( "slowdown. Consider using a fast tokenizer instead." 
) - model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) - model_type = model_cfg.get("model_type", "") if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor From 785d3f77b48444443adeb37184c7e62945f552aa Mon Sep 17 00:00:00 2001 From: Developer Date: Thu, 5 Feb 2026 09:14:26 +0000 Subject: [PATCH 22/23] fix v1 streaming --- lightllm/server/api_openai.py | 40 +++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index d91bb1d94..349ae0334 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -338,6 +338,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason = None + has_emitted_tool_calls = False from .req_id_generator import convert_sub_id_to_group_id prompt_tokens = 0 @@ -358,7 +359,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if reasoning_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(reasoning_content=reasoning_text), + delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text), finish_reason=None, ) chunk = ChatCompletionStreamResponse( @@ -386,8 +387,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if normal_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(content=normal_text), - finish_reason=finish_reason if finish_reason else None, + delta=DeltaMessage(role="assistant", content=normal_text), + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -400,6 +401,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) for call_item in calls: + has_emitted_tool_calls = True # transform call_item -> FunctionResponse + ToolCall if finish_reason == "stop": latest_delta_len = 0 @@ -436,7 +438,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choice_data = ChatCompletionStreamResponseChoice( index=0, delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), - finish_reason="tool_calls", + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -446,22 +448,34 @@ async def stream_results() -> AsyncGenerator[bytes, None]: ) yield f"data: {chunk.model_dump_json()}\n\n" else: - group_request_id = convert_sub_id_to_group_id(sub_req_id) - delta_message = DeltaMessage(role="assistant", content=delta) - if finish_status.is_finished(): - finish_reason = finish_status.get_finish_reason() - stream_choice = ChatCompletionStreamResponseChoice( - index=0, delta=delta_message, finish_reason=finish_reason - ) + stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None) stream_resp = ChatCompletionStreamResponse( id=group_request_id, created=created_time, model=request.model, choices=[stream_choice], ) - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") - # Additional usage chunk + yield f"data: {stream_resp.model_dump_json()}\n\n" + + # Determine final finish_reason: override to "tool_calls" if tool calls were emitted + if has_emitted_tool_calls and finish_reason == "stop": + finish_reason = "tool_calls" + + # Final empty chunk containing only finish_reason 
(and role) + if finish_reason is not None: + final_choice = ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(), + finish_reason=finish_reason, + ) + final_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[final_choice], + ) + yield f"data: {final_chunk.model_dump_json()}\n\n" if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( From 2aed50f383a7a6fdc827ca37012be8ab69fa1bcb Mon Sep 17 00:00:00 2001 From: Developer Date: Thu, 5 Feb 2026 09:19:11 +0000 Subject: [PATCH 23/23] exclude_none --- lightllm/server/api_openai.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 349ae0334..ee8a35fd6 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -368,7 +368,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" if request.tool_choice != "none" and request.tools: if index not in parser_dict: @@ -396,7 +396,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) @@ -446,7 +446,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" else: delta_message = DeltaMessage(role="assistant", content=delta) stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None) @@ -456,7 +456,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, choices=[stream_choice], ) - yield f"data: {stream_resp.model_dump_json()}\n\n" + yield f"data: {stream_resp.model_dump_json(exclude_none=True)}\n\n" # Determine final finish_reason: override to "tool_calls" if tool calls were emitted if has_emitted_tool_calls and finish_reason == "stop": @@ -475,7 +475,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, choices=[final_choice], ) - yield f"data: {final_chunk.model_dump_json()}\n\n" + yield f"data: {final_chunk.model_dump_json(exclude_none=True)}\n\n" if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( @@ -490,7 +490,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) @@ -691,7 +691,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
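
Taken together, the "fix v1 streaming" patch above changes the streaming contract of the OpenAI-compatible endpoint: reasoning, content, and tool-call chunks are all emitted with finish_reason=None, and a single terminating chunk with an empty delta carries the final reason, overridden to "tool_calls" whenever at least one tool-call delta was streamed. The helper below is a hypothetical stand-alone restatement of that decision rule (it does not exist in api_openai.py); it only makes the override easy to check in isolation:

    def resolve_final_finish_reason(finish_reason, has_emitted_tool_calls):
        # Hypothetical helper mirroring the logic inside stream_results().
        # No finish reason yet: generation is still running, so no terminating
        # empty-delta chunk is emitted at all.
        if finish_reason is None:
            return None
        # Tool-call deltas were streamed and generation otherwise ended with
        # "stop": the final chunk reports "tool_calls" instead.
        if has_emitted_tool_calls and finish_reason == "stop":
            return "tool_calls"
        # Any other reason ("length", "stop" without tool calls, ...) passes through.
        return finish_reason

    assert resolve_final_finish_reason("stop", has_emitted_tool_calls=True) == "tool_calls"
    assert resolve_final_finish_reason("stop", has_emitted_tool_calls=False) == "stop"
    assert resolve_final_finish_reason("length", has_emitted_tool_calls=True) == "length"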
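
The final "exclude_none" patch then makes every streamed chunk serialize with model_dump_json(exclude_none=True), so optional delta fields that were never set (content, reasoning_content, tool_calls, ...) are dropped from the SSE payload instead of appearing as explicit nulls. A minimal sketch of the effect, using a simplified stand-in for the server's DeltaMessage schema rather than the real class (pydantic v2):

    from typing import Optional
    from pydantic import BaseModel

    class Delta(BaseModel):
        # Simplified stand-in: the real DeltaMessage carries more fields.
        role: Optional[str] = None
        content: Optional[str] = None
        reasoning_content: Optional[str] = None

    d = Delta(role="assistant", content="hello")
    print(d.model_dump_json())
    # {"role":"assistant","content":"hello","reasoning_content":null}
    print(d.model_dump_json(exclude_none=True))
    # {"role":"assistant","content":"hello"}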