diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 80df545498..0eea52cc89 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -10,9 +10,14 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +# NSA backend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend + from .create_utils import ( get_prefill_att_backend_class, get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, + get_nsa_prefill_att_backend_class, + get_nsa_decode_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca84..1286a46ec2 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -65,6 +65,11 @@ class AttControl: mla_prefill_dict: Dict = None mla_decode: bool = False mla_decode_dict: Dict = None + # nsa (native sparse attention) 专用传参项 + nsa_prefill: bool = False + nsa_prefill_dict: Dict = None + nsa_decode: bool = False + nsa_decode_dict: Dict = None @dataclass diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 19252cf13a..e3bf81daed 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -16,6 +16,7 @@ from .flashinfer.fp8 import Fp8FlashInferAttBackend from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend logger = init_logger(__name__) @@ -46,6 +47,13 @@ }, } +nsa_data_type_to_backend = { + "None": { + "flashmla_sparse": NsaFlashMlaSparseAttBackend, + # Future backends: "fa3", "tilelang", "aiter" + }, +} + def _auto_select_backend( llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] @@ -105,3 +113,19 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla return mla_data_type_to_backend[llm_dtype][backend_str] else: return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + + +def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: + llm_dtype = "None" + if backend_str not in nsa_data_type_to_backend[llm_dtype]: + logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") + backend_str = "flashmla_sparse" + return nsa_data_type_to_backend[llm_dtype][backend_str] + + +def get_nsa_decode_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: + llm_dtype = "None" + if backend_str not in nsa_data_type_to_backend[llm_dtype]: + logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") + backend_str = "flashmla_sparse" + return nsa_data_type_to_backend[llm_dtype][backend_str] diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py new file mode 100644 index 0000000000..11a1ebfdcd --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/__init__.py @@ -0,0 +1,13 @@ +"""NSA (Native Sparse Attention) backend implementations.""" + +from .flashmla_sparse import ( + NsaFlashMlaSparseAttBackend, + NsaFlashMlaSparsePrefillAttState, + NsaFlashMlaSparseDecodeAttState, +) + +__all__ = [ + "NsaFlashMlaSparseAttBackend", + 
"NsaFlashMlaSparsePrefillAttState", + "NsaFlashMlaSparseDecodeAttState", +] diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py new file mode 100644 index 0000000000..2c347ed32b --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -0,0 +1,134 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/nsa_backend.py +# Uses sgl_kernel.flash_mla and sgl_kernel.flash_attn from the sglang kernel library. + +import dataclasses +import torch +from typing import Tuple, TYPE_CHECKING + +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class NsaFlashMlaSparseAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState": + return NsaFlashMlaSparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparseDecodeAttState": + return NsaFlashMlaSparseDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaSparsePrefillAttState(BasePrefillAttState): + """Prefill attention state for NSA using flash_mla_sparse_fwd.""" + + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + + return self._nsa_prefill_att(q=q, kv=k, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_mla import flash_mla_sparse_fwd + + nsa_dict = att_control.nsa_prefill_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + mla_out, _, _ = flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=topk_indices, + sm_scale=softmax_scale, + d_v=kv_lora_rank, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): + + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + + return self._nsa_decode_att(q=q, kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: 
torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_attn import flash_attn_with_kvcache + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + nsa_cache_seqlens = nsa_dict["nsa_cache_seqlens"] + nsa_cu_seqlens_k = nsa_dict["nsa_cu_seqlens_k"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + qk_rope_head_dim = nsa_dict["qk_rope_head_dim"] + + q_nope, q_rope = q + + # Extract k_rope and kv_nope from the KV buffer + k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank) + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=topk_indices, + cache_seqlens=nsa_cache_seqlens, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=nsa_cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=softmax_scale, + causal=True, + ) + return o_tensor diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03c..8f54e14a72 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -98,7 +98,7 @@ def _init_parallel_params(self): self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_ if self.enable_ep_moe: assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe" - logger.info( + logger.debug( f"global_rank {self.global_rank_} layerindex {self.layer_num_} " f"redundancy_expertids: {self.redundancy_expert_ids}" ) diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 86a887a259..0f4d6b13ae 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -36,11 +36,20 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: + if isinstance(attr, TransformerLayerWeight): + attr.load_hf_weights(weights) + elif isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + def verify_load(self): + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, TransformerLayerWeight): + attr.verify_load() + super().verify_load() + def get_quant_method(self, name): return self.quant_cfg.get_quant_method(self.layer_num_, name) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 40c8aa993e..33bdca4475 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -7,6 +7,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.utils.config_utils import get_vocab_size +from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager logger = init_logger(__name__) @@ -155,7 +156,11 @@ def init_req_sampling_params(self, req): else: self.req_to_out_token_id_counter[req.req_idx].fill_(0) if 
req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics: - prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True) + prompt_ids = g_pin_mem_manager.gen_from_list( + key="prompt_ids_for_penalty", + data=req.shm_req.get_prompt_ids_numpy(), + dtype=torch.int32, + ).cuda(non_blocking=True) token_id_counter( prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx] ) @@ -214,22 +219,13 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List): cum_sum_len += len(id_to_count) p_cumsum_seq_len.append(cum_sum_len) - from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager - - p_token_ids_tensor = g_pin_mem_manager.alloc_pin_tensor( - key="p_token_ids", size=len(p_token_ids), dtype=torch.int32 - ) - p_token_ids_tensor.numpy()[:] = p_token_ids - - p_token_counts_tensor = g_pin_mem_manager.alloc_pin_tensor( - key="p_token_counts", size=len(p_token_counts), dtype=torch.int32 + p_token_ids_tensor = g_pin_mem_manager.gen_from_list(key="p_token_ids", data=p_token_ids, dtype=torch.int32) + p_token_counts_tensor = g_pin_mem_manager.gen_from_list( + key="p_token_counts", data=p_token_counts, dtype=torch.int32 ) - p_token_counts_tensor.numpy()[:] = p_token_counts - - p_cumsum_seq_len_tensor = g_pin_mem_manager.alloc_pin_tensor( - key="p_cumsum_seq_len", size=len(p_cumsum_seq_len), dtype=torch.int32 + p_cumsum_seq_len_tensor = g_pin_mem_manager.gen_from_list( + key="p_cumsum_seq_len", data=p_cumsum_seq_len, dtype=torch.int32 ) - p_cumsum_seq_len_tensor.numpy()[:] = p_cumsum_seq_len return ( p_token_ids_tensor.cuda(non_blocking=True), diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 095f736791..fdd277f369 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -18,6 +18,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, diff --git a/lightllm/models/deepseek3_2/__init__.py b/lightllm/models/deepseek3_2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek3_2/encoding_dsv32.py b/lightllm/models/deepseek3_2/encoding_dsv32.py new file mode 100644 index 0000000000..3ac4b83714 --- /dev/null +++ b/lightllm/models/deepseek3_2/encoding_dsv32.py @@ -0,0 +1,429 @@ +# Adapted from vLLM's deepseek_v32_encoding.py +# (https://github.com/vllm-project/vllm), which was originally adapted from +# https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py +# +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import json +import re +from typing import Any + +# flake8: noqa: E501 +TOOLS_SYSTEM_TEMPLATE = """## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user: +<{dsml_token}function_calls> +<{dsml_token}invoke name="$FUNCTION_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$FUNCTION_NAME2"> +... 
+ + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: +<{dsml_token}function_calls> +... + + +... + +{thinking_start_token}...thinking about results{thinking_end_token} +Here are the functions available in JSONSchema format: + +{tool_schemas} + +""" + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" +system_msg_template: str = "{content}" +user_msg_template: str = "<|User|>{content}<|Assistant|>" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" +thinking_template = "{reasoning}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = '<{dsml_token}invoke name="{name}">\n{arguments}\n' +tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n" + +tool_output_template: str = "\n{content}" + + +def to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: dict) -> str: + p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}""" + P_dsml_strs = [] + if isinstance(tool_call["arguments"], str): + arguments = json.loads(tool_call["arguments"]) + else: + arguments = tool_call["arguments"] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + value=v if isinstance(v, str) else to_json(v), + ) + + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments(tool_name, tool_args): + def _decode_value(key, value, string): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}" + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools): + tools_json = [to_json(t) for t in tools] + + return TOOLS_SYSTEM_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages): + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +def render_message(index, messages, thinking_mode): + if not (0 <= index < len(messages)): + raise 
ValueError(f"Index {index} out of range for messages list of length {len(messages)}") + if thinking_mode not in ["chat", "thinking"]: + raise ValueError(f"Invalid thinking_mode `{thinking_mode}`") + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning = msg.get("reasoning") + is_prefix = msg.get("prefix", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + + if response_format: + prompt += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + elif role == "developer": + if not content: + raise ValueError(f"Invalid message for role `{role}`: {msg}") + content_developer = "" + if tools: + content_developer += "\n\n" + render_tools(tools) + + if response_format: + content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + content_developer += "\n\n# The user's message is: {}".format(content) + + prompt += user_msg_template.format(content=content_developer) + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "user": + prompt += user_msg_template.format(content=content) + + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "tool": + prev_assistant_idx = index - 1 + assistant_msg = messages[prev_assistant_idx] + while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool": + prev_assistant_idx -= 1 + assistant_msg = messages[prev_assistant_idx] + + if not (index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant"): + raise ValueError(f"Invalid messages at {index}:\n{assistant_msg}") + + tool_call_order = index - prev_assistant_idx + assistant_tool_calls = assistant_msg.get("tool_calls") + if not (assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order): + raise ValueError("No tool calls but found tool output") + + if tool_call_order == 1: + prompt += "\n\n" + + prompt += tool_output_template.format(content=content) + + if tool_call_order == len(assistant_tool_calls): + prompt += "\n" + + if index >= last_user_idx and thinking_mode == "thinking": + prompt += "\n\n" + thinking_start_token + else: + prompt += "\n\n" + thinking_end_token + + elif role == "assistant": + thinking_part = "" + + tool_calls_content = "" + if tool_calls: + tool_calls = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tool_call.get("name"), + arguments=encode_arguments_to_dsml(tool_call), + ) + for tool_call in tool_calls + ] + tool_calls_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, tool_calls="\n".join(tool_calls) + ) + + summary_content = content or "" + + if thinking_mode == "thinking" and index > last_user_idx: + if not (reasoning or tool_calls): + raise ValueError( + f"ThinkingMode: {thinking_mode}, invalid message without reasoning/tool_calls `{msg}` after last user message" + ) + thinking_part = thinking_template.format(reasoning=reasoning or "") + thinking_end_token + + if not tool_calls and is_prefix: + prompt += summary_content + else: 
+ prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tool_calls_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + return prompt + + +def drop_thinking_messages(messages, last_user_idx=None): + messages_wo_thinking = [] + last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in ["user", "system", "tool"] or idx >= last_user_idx: + messages_wo_thinking.append(msg) + continue + + elif role == "assistant": + msg_wo_thinking = copy.copy(msg) + msg_wo_thinking.pop("reasoning", None) + messages_wo_thinking.append(msg_wo_thinking) + + return messages_wo_thinking + + +def encode_messages( + messages, + thinking_mode, + context=None, + drop_thinking=True, + add_default_bos_token=True, +): + context = context if context else [] + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + if thinking_mode == "thinking" and drop_thinking: + full_messages = drop_thinking_messages(full_messages) + + for idx in range(len(messages)): + prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode) + + return prompt + + +def _read_until_stop(index, text, stop): + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index, text): + tool_calls = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token]) + if _ != ">\n": + raise RuntimeError("Tool call format error") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise RuntimeError("Missing special token") + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL) + if len(p_tool_name) != 1: + raise RuntimeError("Tool name format error") + tool_name = p_tool_name[0] + + tool_args = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"]) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + if len(param_kv) != 1: + raise RuntimeError("Parameter format error") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise RuntimeError("Duplicate parameter name") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n": + raise RuntimeError("Parameter format error") + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +# NOTE: This function is designed to parse only correctly +# formatted string and will not attempt to correct malformed output +# that may be generated by the model. 
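+# Malformed completions surface as RuntimeError rather than being silently repaired.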
+def parse_message_from_completion_text(text, thinking_mode): + summary_content, reasoning, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}function_calls" + + is_thinking, is_tool_calling = thinking_mode == "thinking", False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token]) + reasoning = content_delta + if stop_token != thinking_end_token: + raise RuntimeError("Invalid thinking format") + + index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token]) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + if stop_token != eos_token: + raise RuntimeError("Invalid summary format") + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + if tool_ends_text: + raise RuntimeError("Unexpected content after tool calls") + + if not (len(text) == index and stop_token in [eos_token, None]): + raise RuntimeError("Unexpected content at end") + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + if sp_token in summary_content or sp_token in reasoning: + raise RuntimeError("Unexpected special token in content") + + return { + "role": "assistant", + "content": summary_content, + "reasoning": reasoning, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py new file mode 100644 index 0000000000..779c2fc2d2 --- /dev/null +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -0,0 +1,209 @@ +import torch +import weakref +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager + + +class Deepseek3_2InferStateInfo(Deepseek2InferStateInfo): + _shared_nsa_buffers = None + + def __init__(self): + super().__init__() + self.lengths = None + self.page_table_size_1 = None + self.ks = None + self.ke = None + self.nsa_cu_seqlens_k = None + self.index_topk = 2048 + return + + @classmethod + def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): + """Get or create pre-allocated buffers for CUDA graph execution""" + if cls._shared_nsa_buffers is None: + max_total_q_tokens = graph_max_batch_size * max_seq_len + max_total_tokens = graph_max_batch_size * max_seq_len + + cls._shared_nsa_buffers = [ + { + "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device="cuda"), + "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), + "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), + "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), + }, + { + "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, 
device="cuda"), + "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), + "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), + "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), + }, + ] + return cls._shared_nsa_buffers + + def _check_use_cuda_graph_buffers(self): + if hasattr(self, "_model_ref"): + model = self._model_ref() + if ( + model is not None + and hasattr(model, "graph_max_batch_size") + and hasattr(model, "graph_max_len_in_batch") + and self.batch_size <= model.graph_max_batch_size + and self.max_kv_seq_len <= model.graph_max_len_in_batch + ): + return True + return False + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + + self._model_ref = weakref.ref(model) + + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) + self.indexer_ks_buffer = self.mem_manager.indexer_ks_buffer + + if self.is_prefill: + self._init_nsa_indexing_prefill() + else: + if self.b_ready_cache_len is None: + self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) + + use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() + buffer = None + + if use_cuda_graph_buffers: + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + self.nsa_cache_seqlens = buffer["nsa_cache_seqlens"][: self.batch_size] + self.nsa_cu_seqlens_k = buffer["nsa_cu_seqlens_k"][: self.batch_size + 1] + else: + self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") + self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") + + self.nsa_cache_seqlens.copy_(self.b_kv_seq_len.clamp(max=self.index_topk)) + assert self.nsa_cache_seqlens.dtype == torch.int32 + + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) + self.nsa_cu_seqlens_k[0] = 0 + + self._init_nsa_indexing_decode(use_cuda_graph_buffers, buffer) + + def _init_nsa_indexing_decode(self, use_cuda_graph_buffers, buffer): + """Optimized NSA indexing for decode: b_q_seq_len=1 per request. + + In decode, each request generates exactly 1 token, so: + - total_q_len = batch_size (no .item() needed) + - ks[i] = cumsum_offset[i], ke[i] = cumsum_offset[i] + 1 + - lengths[i] = b_seq_len[i] + - No repeat_interleave, no token_in_req math needed. + """ + b_seq_len = self.b_seq_len + b_req_idx = self.b_req_idx + num_seq = self.batch_size + + # Cumulative seq_len offsets for ks/ke: [0, s0, s0+s1, ...] 
+ cum_seq = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) + + if use_cuda_graph_buffers: + model = self._model_ref() + max_seq_len = model.graph_max_len_in_batch + + # ks, ke, lengths — write directly into buffer slices + buf_ks = buffer["ks"][:num_seq] + buf_ke = buffer["ke"][:num_seq] + buf_lengths = buffer["lengths"][:num_seq] + + # ks[0] = 0, ks[i] = cum_seq[i-1] + buf_ks[0] = 0 + if num_seq > 1: + buf_ks[1:].copy_(cum_seq[: num_seq - 1]) + # ke = ks + 1 + torch.add(buf_ks, 1, out=buf_ke) + # lengths = b_seq_len + buf_lengths.copy_(b_seq_len.int()) + + self.ks = buf_ks + self.ke = buf_ke + self.lengths = buf_lengths + + # page_table: zero buffer slice, then fill valid entries + page_table = buffer["page_table_size_1"][:num_seq, :max_seq_len] + page_table.zero_() + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=b_seq_len.device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + + # req_all_mem_index: use padded [num_seq * max_seq_len] layout + # Downstream uses ks/ke masking so padded entries are safe + max_total_seq = num_seq * max_seq_len + buf_mem = buffer["req_all_mem_index"][:max_total_seq] + buf_mem.copy_(all_rows.reshape(-1)) + self.req_all_mem_index = buf_mem + else: + # Non-CUDA-graph decode: simplified formulas, fresh tensors + max_seq_len = b_seq_len.max().item() + + # ks/ke/lengths + seq_offsets = torch.empty_like(cum_seq) + seq_offsets[0] = 0 + if num_seq > 1: + seq_offsets[1:] = cum_seq[:-1] + + self.ks = seq_offsets + self.ke = (seq_offsets + 1).int() + self.lengths = b_seq_len.int() + + # page_table and req_all_mem_index + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=b_seq_len.device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + + page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=b_seq_len.device) + page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + + self.req_all_mem_index = all_rows[valid_mask] + + def _init_nsa_indexing_prefill(self): + """NSA indexing for prefill: variable q lengths, generic vectorized path.""" + b_seq_len = self.b_seq_len + b_q_seq_len = self.b_q_seq_len + b_req_idx = self.b_req_idx + num_seq = b_req_idx.shape[0] + device = b_seq_len.device + + max_seq_len = b_seq_len.max().item() + total_q_len = b_q_seq_len.sum().item() + + # page_table_size_1 and req_all_mem_index + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + + page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=device) + page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + self.req_all_mem_index = all_rows[valid_mask] + + # ks, ke, lengths — generic vectorized for variable q lengths + cum_seq = torch.cumsum(b_seq_len, dim=0) + seq_offsets = torch.zeros_like(cum_seq) + seq_offsets[1:] = cum_seq[:-1] + + req_indices = torch.repeat_interleave(torch.arange(num_seq, device=device), b_q_seq_len) + + cum_q = torch.cumsum(b_q_seq_len, dim=0) + q_offsets = torch.zeros_like(cum_q) + q_offsets[1:] = cum_q[:-1] + token_in_req = torch.arange(total_q_len, device=device) - q_offsets[req_indices] + + self.ks = seq_offsets[req_indices].int() + self.ke = 
(seq_offsets[req_indices] + token_in_req + 1).int() + self.lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req + 1).int() diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py new file mode 100644 index 0000000000..7a9aeb46c9 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -0,0 +1,152 @@ +from sgl_kernel import fast_topk_transform_fused +import deep_gemm +import torch +from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant +from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class NSAIndexerInfer(BaseLayerInfer): + def __init__(self, layer_idx, network_config): + super().__init__() + self.layer_idx_ = layer_idx + self.network_config_ = network_config + self.index_topk = network_config["index_topk"] + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = 1 + self.tp_v_head_num_ = 1 + self.qk_nope_head_dim = network_config["qk_nope_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_head_dim = network_config["index_head_dim"] + self.eps = network_config["rms_norm_eps"] + self.block_size = network_config["quantization_config"]["weight_block_size"][0] + self.scale_fmt = network_config["quantization_config"]["scale_fmt"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.index_n_heads = network_config["index_n_heads"] + self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale + + return + + def ref_fp8_mqa_logits( + self, + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + cost_only: bool = False, + ): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + cost = mask.sum() + return logits, cost + + def get_indices( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2InferStateInfo, + layer_weight: NSAIndexerWeight, + ) -> torch.Tensor: + + q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) + q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + + destindex_copy_indexer_ks( + k_fp8, k_scale, infer_state.mem_index, 
infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_] + ) + + weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale + weights = weights.unsqueeze(-1) * q_scale + + ks = infer_state.ks + ke = infer_state.ke + lengths = infer_state.lengths + page_table_1 = infer_state.page_table_size_1 + + # Use efficient Triton kernel to extract FP8 keys and scales from buffer + k_fp8_, k_scale_ = extract_indexer_ks( + infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index + ) + + # Get actual sequence length from q (which comes from q_lora) + # This may differ from ks.shape[0] during certain operations + actual_seq_len = q.shape[0] + + # ks, ke, lengths, and weights should all match actual_seq_len + # Slice them if they don't match + if ks.shape[0] != actual_seq_len: + ks = ks[:actual_seq_len] + ke = ke[:actual_seq_len] + lengths = lengths[:actual_seq_len] + weights = weights[:actual_seq_len] + + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + + return fast_topk_transform_fused( + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.b1_cu_q_seq_len, + topk=self.index_topk, + ) + + @staticmethod + def _rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from sgl_kernel import hadamard_transform + + hidden_size = x.size(-1) + assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." + return hadamard_transform(x, scale=hidden_size ** -0.5) + + def _get_q_k_bf16( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2InferStateInfo, + layer_weight: NSAIndexerWeight, + ): + q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) + k = layer_weight.wk_proj_.mm(hidden_states) + + k = layer_weight.k_norm_(k, eps=self.eps) + + # Slice position_cos and position_sin to match actual token length + actual_seq_len = q.shape[0] + rotary_emb_fwd( + q[:, :, : self.qk_rope_head_dim], + k[:, None, : self.qk_rope_head_dim], + infer_state.position_cos[:actual_seq_len], + infer_state.position_sin[:actual_seq_len], + ) + + q = self._rotate_activation(q) + k = self._rotate_activation(k) + return q, k diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..13a0c1394f --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -0,0 +1,137 @@ +import torch + +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer +from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo +from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.attention.create_utils import get_nsa_prefill_att_backend_class + + +class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): + def __init__(self, 
layer_num, network_config): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, network_config) + + self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_) + self.topk_indices = None + + # Initialize NSA attention backend (singleton, lazy initialization) + self._nsa_backend_class = get_nsa_prefill_att_backend_class() + self._nsa_backend = None + return + + def _get_nsa_backend(self): + """Get or create the NSA backend (lazy initialization).""" + if self._nsa_backend is None: + # NSA backend doesn't require model reference for basic operations + self._nsa_backend = self._nsa_backend_class(model=None) + return self._nsa_backend + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Deepseek3_2InferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + + self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) + + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + rmsnorm_forward( + cache_kv[:, :, : self.kv_lora_rank], + weight=layer_weight.kv_a_layernorm_.weight, + eps=self.eps_, + out=cache_kv[:, :, : self.kv_lora_rank], + ) + + rotary_emb_fwd( + q_rope, + cache_kv[:, :, self.kv_lora_rank :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek3_2InferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + # Model-specific q projection (uses layer weights) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + q_all = torch.cat([q_nope, q_rope], dim=-1) + + # Use NSA backend for attention computation + att_control = AttControl( + nsa_prefill=True, + nsa_prefill_dict={ + "topk_indices": self.topk_indices, + "softmax_scale": self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + }, + ) + + # Create prefill state and execute attention + nsa_backend = self._get_nsa_backend() + prefill_state = nsa_backend.create_att_prefill_state(infer_state) + prefill_state.init_state() + mla_out = prefill_state.prefill_att( + q=q_all, + k=infer_state.mem_manager.kv_buffer[self.layer_num_], + v=None, + att_control=att_control, + ) + return mla_out + + def _token_attention_kernel( + self, + q, + infer_state: Deepseek3_2InferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, + ): + # Model-specific q projection (uses layer weights) + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + + # Use NSA backend for attention computation + att_control = AttControl( + nsa_decode=True, + nsa_decode_dict={ + "topk_indices": self.topk_indices, + "nsa_cache_seqlens": infer_state.nsa_cache_seqlens, + "nsa_cu_seqlens_k": infer_state.nsa_cu_seqlens_k, + "softmax_scale": 
self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + "qk_rope_head_dim": self.qk_rope_head_dim, + }, + ) + + # Create decode state and execute attention + nsa_backend = self._get_nsa_backend() + decode_state = nsa_backend.create_att_decode_state(infer_state) + decode_state.init_state() + o_tensor = decode_state.decode_att( + q=(q_nope, q_rope), + k=infer_state.mem_manager.kv_buffer[self.layer_num_], + v=None, + att_control=att_control, + ) + return o_tensor diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py new file mode 100644 index 0000000000..023b89979b --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -0,0 +1,55 @@ +from typing_extensions import override + +import torch + +from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, LayerNormWeight + + +class NSAIndexerWeight(TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _parse_config(self): + self.q_lora_rank = self.network_config_["q_lora_rank"] + self.index_n_heads = self.network_config_["index_n_heads"] + self.index_head_dim = self.network_config_["index_head_dim"] + self.hidden_size = self.network_config_["hidden_size"] + + def _init_weight(self): + prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" + + self.wq_b_proj_ = ROWMMWeight( + in_dim=self.q_lora_rank, + out_dims=[self.index_n_heads * self.index_head_dim], + weight_names=f"{prefix}.wq_b.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.wk_proj_ = ROWMMWeight( + in_dim=self.hidden_size, + out_dims=[self.index_head_dim], + weight_names=f"{prefix}.wk.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.k_norm_ = LayerNormWeight( + dim=self.index_head_dim, + weight_name=f"{prefix}.k_norm.weight", + data_type=self.data_type_, + bias_name=f"{prefix}.k_norm.bias", + ) + self.weights_proj_ = ROWMMWeight( + in_dim=self.hidden_size, + out_dims=[self.index_n_heads], + weight_names=f"{prefix}.weights_proj.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..adcba51cc9 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -0,0 +1,12 @@ +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight + + +class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, data_type, network_config, quant_cfg) + self.indexer_layer_weight = NSAIndexerWeight( + layer_num=layer_num, data_type=data_type, network_config=network_config, quant_cfg=quant_cfg + ) + return diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py new file mode 100644 index 0000000000..fdb2e87c6b 
--- /dev/null +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -0,0 +1,31 @@ +import torch + +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + + +class IndexerKSBuffer: + def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int, dtype=torch.uint8): + self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + + +class Deepseek3_2MemoryManager(Deepseek2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, layer_num) + + def get_cell_size(self): + return super().get_cell_size() + 132 + + def _free_buffers(self): + super()._free_buffers() + self.indexer_ks_buffer = None + + def resize_mem(self, new_size): + super().resize_mem(new_size) + self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, self.layer_num) + + +class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py new file mode 100644 index 0000000000..77804096b1 --- /dev/null +++ b/lightllm/models/deepseek3_2/model.py @@ -0,0 +1,162 @@ +import copy +import json +import logging +import os + +from lightllm.models.registry import ModelRegistry +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.utils.envs_utils import get_env_start_args + +_logger = logging.getLogger(__name__) + +# When ENABLE_NSA is set, use the full V32 NSA (Native Sparse Attention) pipeline +# including the indexer, custom memory manager, and NSA-aware attention kernels. +# When not set, fall back to the DeepSeek V3 (Deepseek2) inference path while +# keeping V32-specific tokenizer/parser support intact. +_ENABLE_NSA = os.environ.get("ENABLE_NSA", "0").lower() in ("1", "true") + +if _ENABLE_NSA: + from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight + from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer + from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo + from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager + + +class DeepSeekV32Tokenizer: + """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based + encoding_dsv32 module instead of Jinja chat templates. + + DeepSeek-V3.2's tokenizer_config.json does not ship with a Jinja chat + template, so ``apply_chat_template`` would fail without either a manually + supplied ``--chat_template`` file or this wrapper. Activate it with + ``--tokenizer_mode deepseek_v32``. + """ + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + # Cache added vocabulary for performance (HuggingFace can be slow). + self._added_vocab = None + + # ------------------------------------------------------------------ + # Attribute delegation – everything not overridden goes to the inner + # tokenizer so that encode/decode/vocab_size/eos_token_id/… all work. 
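+ # __getattr__ only fires for attribute lookups that fail normally, so the explicit overrides below take precedence.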
+ # ------------------------------------------------------------------ + def __getattr__(self, name): + return getattr(self.tokenizer, name) + + def get_added_vocab(self): + if self._added_vocab is None: + self._added_vocab = self.tokenizer.get_added_vocab() + return self._added_vocab + + # ------------------------------------------------------------------ + # Core override: route apply_chat_template through encode_messages. + # ------------------------------------------------------------------ + def apply_chat_template( + self, + conversation=None, + messages=None, + tools=None, + tokenize=False, + add_generation_prompt=True, + thinking=None, + **kwargs, + ): + from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages, render_tools + + msgs = conversation if conversation is not None else messages + if msgs is None: + raise ValueError("Either 'conversation' or 'messages' must be provided") + + # Deep copy to avoid mutating the caller's messages. + msgs = copy.deepcopy(msgs) + + # Determine thinking mode. + thinking_mode = "thinking" if thinking else "chat" + + # Inject tools into the first system message (or create one) so that + # encode_messages / render_message picks them up. + if tools: + # build_prompt passes tools as bare function dicts: + # [{"name": "f", "description": "...", "parameters": {...}}] + # encoding_dsv32's render_message expects OpenAI wrapper format: + # [{"type": "function", "function": {...}}] + wrapped_tools = [] + for t in tools: + if "function" in t: + wrapped_tools.append(t) + else: + wrapped_tools.append({"type": "function", "function": t}) + + injected = False + for msg in msgs: + if msg.get("role") == "system": + existing = msg.get("tools") or [] + msg["tools"] = existing + wrapped_tools + injected = True + break + + if not injected: + # Prepend a system message that carries the tools. + msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) + + prompt = encode_messages( + msgs, + thinking_mode=thinking_mode, + drop_thinking=kwargs.get("drop_thinking", True), + add_default_bos_token=kwargs.get("add_default_bos_token", True), + ) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt + + +@ModelRegistry(["deepseek_v32"]) +class Deepseek3_2TpPartModel(Deepseek2TpPartModel): + # When ENABLE_NSA is set, override with V32-specific NSA classes. + # Otherwise, inherit the V3/V2 classes from Deepseek2TpPartModel. + if _ENABLE_NSA: + transformer_weight_class = Deepseek3_2TransformerLayerWeight + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + infer_state_class = Deepseek3_2InferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + if _ENABLE_NSA: + self.index_topk = self.config["index_topk"] + else: + _logger.info("ENABLE_NSA is not set, using DeepSeek V3 inference path (no NSA indexer).") + return + + def _init_inferstate_cls(self): + if _ENABLE_NSA: + self.infer_state_class = Deepseek3_2InferStateInfo + else: + super()._init_inferstate_cls() + + def _init_mem_manager(self): + if not _ENABLE_NSA: + # Fall back to the standard V3/V2 memory manager (no indexer buffer). 
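+ # (The indexer buffer caches per-token FP8 indexer keys plus their float32 scales; see Deepseek3_2MemoryManager.)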
+ return super()._init_mem_manager() + + manager_class = Deepseek3_2MemoryManager + if get_env_start_args().llm_kv_type == "fp8kv": + manager_class = Deepseek3_2FP8KVMemoryManager + + # mtp 模式下需要在mem manger上扩展draft model使用的layer + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "deepseekv3_eagle": + added_mtp_layer_num += 1 + elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": + added_mtp_layer_num += get_env_start_args().mtp_step + + self.mem_manager = manager_class( + self.max_total_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + mem_fraction=self.mem_fraction, + ) + return diff --git a/lightllm/models/deepseek3_2/triton_kernel/__init__.py b/lightllm/models/deepseek3_2/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek3_2/triton_kernel/act_quant.py b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py new file mode 100644 index 0000000000..a4ecd0f518 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py @@ -0,0 +1,137 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/python/sglang/srt/layers/attention/nsa/triton_kernel.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +# Triton implementation +@triton.jit +def _act_quant_kernel( + X_ptr, + Y_ptr, + S_ptr, + M, + N, + group_size: tl.constexpr, + round_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Triton kernel for activation quantization. + + Each block processes BLOCK_M rows and group_size columns. + """ + # Get block IDs + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # FP8 constants + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1.0 / fp8_max + + # Calculate row and column offsets + row_start = pid_m * BLOCK_M + col_start = pid_n * group_size + + # Create offset arrays + rows = row_start + tl.arange(0, BLOCK_M) + cols = col_start + tl.arange(0, BLOCK_N) + + # Mask for valid rows and columns + row_mask = rows < M + col_mask = cols < N + mask = row_mask[:, None] & col_mask[None, :] + + # Load input data + x_ptrs = X_ptr + rows[:, None] * N + cols[None, :] + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Compute absolute max along columns (group_size dimension) for each row + x_abs = tl.abs(x) + amax = tl.max(x_abs, axis=1) # Shape: (BLOCK_M,) + + # Clamp amax to avoid division by zero + amax = tl.maximum(amax, 1e-4) + + # Compute scale + if round_scale: + # Fast round scale using bit manipulation approximation + # This is a simplified version - the exact bit manipulation is harder in Triton + # Using log2 + ceil + pow2 as approximation + log_val = tl.log2(amax * fp8_max_inv) + log_ceil = tl.ceil(log_val) + scale = tl.exp2(log_ceil) + else: + scale = amax * fp8_max_inv + + # Quantize: y = clamp(x / scale, fp8_min, fp8_max) + scale_broadcast = scale[:, None] + y = x / scale_broadcast + y = tl.minimum(tl.maximum(y, fp8_min), fp8_max) + + # Store quantized output + y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :] + tl.store(y_ptrs, y, mask=mask) + + # Store scales + s_cols = pid_n + s_ptrs = S_ptr + rows * (N // group_size) + s_cols + s_mask = row_mask + tl.store(s_ptrs, scale, mask=s_mask) + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None +) -> Tuple[torch.Tensor, 
torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization with Triton. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + + # Flatten all dims except last + N = x.size(-1) + x_flat = x.view(-1, N) + M = x_flat.size(0) + + # Allocate output tensors + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + y_flat = y.view(-1, N) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + s_flat = s.view(-1, N // block_size) + + # Launch kernel + BLOCK_M = 32 + BLOCK_N = block_size + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size)) + round_scale = scale_fmt is not None + + _act_quant_kernel[grid]( + x_flat, + y_flat, + s_flat, + M, + N, + group_size=block_size, + round_scale=round_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=0 if round_scale else 2, + ) + + return y, s diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py new file mode 100644 index 0000000000..a345bd1e20 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -0,0 +1,90 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_indexer_ks( + K_fp8, + K_scale, + DestLoc, + O_buffer, + stride_k_bs, + stride_k_d, + stride_scale_bs, + stride_scale_d, + stride_o_bs, + stride_o_h, + stride_o_d, + BLOCK_DMODEL: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # Load destination index for this thread + dest_index = tl.load(DestLoc + cur_index).to(tl.int64) + + # Load K_fp8 (128 values) and K_scale (1 value) from source + k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d + k_fp8 = tl.load(k_fp8_ptrs) + + k_scale = tl.load(K_scale + cur_index * stride_scale_bs) + + # Store K_fp8 to O_buffer[:, 0, :128] + # Convert fp8 to uint8 through bitcast for storage in uint8 buffer + o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d + k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True) + tl.store(o_k_ptrs, k_fp8_as_uint8) + + # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32) + # Convert float32 scale to 4 uint8 bytes using bitcast and bit manipulation + o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d + scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True) + + # Store each byte of the float32 scale (little-endian) + for i in range(4): + byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8) + tl.store(o_scale_ptr + i * stride_o_d, byte_val) + + return + + +@torch.no_grad() +def destindex_copy_indexer_ks( + K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor +): + seq_len = DestLoc.shape[0] + head_dim = K_fp8.shape[1] + + assert head_dim == 128, f"Expected 
head_dim=128, got {head_dim}" + + # Handle cases where tensor lengths don't match (e.g., during prefix cache) + actual_seq_len = min(K_scale.shape[0], seq_len) + if actual_seq_len < seq_len: + K_fp8 = K_fp8[:actual_seq_len] + K_scale = K_scale[:actual_seq_len] + DestLoc = DestLoc[:actual_seq_len] + + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" + + grid = (actual_seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_indexer_ks[grid]( + K_fp8, + K_scale, + DestLoc, + O_buffer, + K_fp8.stride(0), + K_fp8.stride(1), + K_scale.stride(0), + K_scale.stride(1), + O_buffer.stride(0), + O_buffer.stride(1), + O_buffer.stride(2), + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py new file mode 100644 index 0000000000..48bc34ad6e --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -0,0 +1,82 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_extract_indexer_ks( + I_buffer, # Input buffer [large_size, 1, 132] uint8 + SrcLoc, # Source indices [req_size] int32/int64 + O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn + O_scale, # Output scale [req_size] float32 + stride_i_bs, + stride_i_h, + stride_i_d, + stride_o_fp8_bs, + stride_o_fp8_d, + stride_o_scale_bs, + BLOCK_DMODEL: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + + src_index = tl.load(SrcLoc + cur_index).to(tl.int64) + + i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d + k_fp8_as_uint8 = tl.load(i_k_ptrs) + + k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) + + o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d + tl.store(o_k_ptrs, k_fp8) + + i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d + + byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) + byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) + byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) + byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) + + scale_as_uint32 = byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24) + + k_scale = scale_as_uint32.to(tl.float32, bitcast=True) + + o_scale_ptr = O_scale + cur_index * stride_o_scale_bs + tl.store(o_scale_ptr, k_scale) + + return + + +@torch.no_grad() +def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + req_size = SrcLoc.shape[0] + head_dim = 128 + + assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" + assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" + + # Allocate output tensors + O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) + O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) + + grid = (req_size,) + num_warps = 1 + + _fwd_kernel_extract_indexer_ks[grid]( + I_buffer, + SrcLoc, + O_fp8, + O_scale, + I_buffer.stride(0), + I_buffer.stride(1), + I_buffer.stride(2), + O_fp8.stride(0), + O_fp8.stride(1), + O_scale.stride(0), + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + + return O_fp8, O_scale diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py new file mode 
100644 index 0000000000..1c1f72b7d7 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -0,0 +1,141 @@ +import triton +import triton.language as tl +import torch + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + Q_ptr, + KV_ptr, + KVScale_ptr, + Weights_ptr, + MemIndex_ptr, + CuSeqlenKs_ptr, + CuSeqlenKe_ptr, + Output_ptr, + seq_len, + seq_len_kv, + num_heads, + head_dim, + stride_q_seq, + stride_q_head, + stride_q_dim, + stride_kv_pool, + stride_kv_dim, + stride_w_seq, + stride_w_head, + stride_o_seq, + stride_o_kv, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + mask_m = offs_m < seq_len + mask_n = offs_n < seq_len_kv + + mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) + + scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) + + for h in range(num_heads): + weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, mask=mask_m, other=0.0) + score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): + d_start = d_block * BLOCK_SIZE_D + offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) + mask_d = offs_d < head_dim + + q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim + mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] + q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) + + k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim + mask_k = mask_n[:, None] & mask_d[None, :] + k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) + + k = k * scales[:, None] + + score += tl.dot(q, tl.trans(k)) + score = tl.maximum(score, 0.0) + logits += score * weights[:, None] + + mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) + mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) + + mask_lo = offs_n[None, :] >= mask_ks[:, None] + mask_hi = offs_n[None, :] < mask_ke[:, None] + mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] + + logits = tl.where(mask_valid, logits, float("-inf")) + + # Store output + out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv + mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) + tl.store(out_ptrs, logits, mask=mask_out) + + +def fp8_paged_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor = None, +) -> torch.Tensor: + seq_len, num_heads, head_dim = q.shape + seq_len_kv = mem_index.shape[0] + + if out is None: + output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) + else: + output = out + + BLOCK_SIZE_M = 16 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_D = 128 + + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) + + _fp8_paged_mqa_logits_kernel[grid]( + q, + kv, + kv_scale, + weights, + mem_index, + cu_seqlen_ks, + cu_seqlen_ke, + output, + seq_len, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + weights.stride(0), 
+ weights.stride(1), + output.stride(0), + output.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + + return output diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py new file mode 100644 index 0000000000..8079864133 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -0,0 +1,104 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py + +import triton +import triton.language as tl +import torch +from typing import Tuple + +fp8_min = -448.0 +fp8_max = 448.0 +fp8_dtype = torch.float8_e4m3fn + + +@triton.jit +def _per_token_group_quant_mla_deep_gemm_masked_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + masked_m_ptr, + group_size, + y_stride_b, + y_stride_t, + y_q_stride_b, + y_q_stride_t, + y_s_stride_b, + y_s_stride_g, + eps, + fp8_min, + fp8_max, + NUM_GROUP: tl.constexpr, + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor for deep_gemm grouped_gemm_masked. + This function converts the tensor values into float8 values. + y and y_q: (b, t, k) + y_s: (b, k//group_size, t) + """ + t_id = tl.program_id(0) + b_id = tl.program_id(1) + + y_ptr += b_id * y_stride_b + t_id * y_stride_t + y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t + y_s_ptr += b_id * y_s_stride_b + t_id + + if t_id == 0: + tl.store(masked_m_ptr + b_id, tl.num_programs(0)) + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + for gid in range(NUM_GROUP): + y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask) + tl.store(y_s_ptr + gid * y_s_stride_g, y_s) + + +def per_token_group_quant_mla_deep_gemm_masked_fp8( + x: torch.Tensor, + group_size: int = 128, + eps: float = 1e-12, + dtype: torch.dtype = fp8_dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function quantizes input values to float8 values with per-token-group-quantization + for deep_gemm grouped_gemm_masked and specialized for mla absorbed case. 
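+    Returns (derived from the allocation and return statement below):
+        x_q: (b, aligned_m, k) fp8 tensor, where aligned_m rounds the token dimension up to a
+            multiple of 256 (the gemm kernel's max block_m).
+        x_s: per-group scales, allocated as (b, k // group_size, aligned_m) float32 and returned
+            transposed to (b, aligned_m, k // group_size).
+        masked_m: (b,) int32 tensor filled with the real token count m for each batch entry.
+        m, aligned_m: the original and padded token counts, also returned for convenience.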
+ """ + assert x.dim() == 3, "`x` is not a 3d-tensor" + + b, m, k = x.shape + aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel + num_tiles_k = k // group_size + assert num_tiles_k * group_size == k, f"k % {group_size} must be zero" + + x_q = x.new_empty((b, aligned_m, k), dtype=dtype) + x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32) + masked_m = x.new_empty((b,), dtype=torch.int32) + + BLOCK_SIZE = triton.next_power_of_2(group_size) + grid = (m, b) + + _per_token_group_quant_mla_deep_gemm_masked_fp8[grid]( + x, + x_q, + x_s, + masked_m, + group_size, + x.stride(0), + x.stride(1), + x_q.stride(0), + x_q.stride(1), + x_s.stride(0), + x_s.stride(1), + eps, + -fp8_max, + fp8_max, + num_tiles_k, + BLOCK_SIZE, + ) + + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 73b9bad4a4..877f029a75 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -92,9 +92,12 @@ def make_argument_parser() -> argparse.ArgumentParser: "--tokenizer_mode", type=str, default="fast", - help="""tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, - slow mode is good for debug and test, fast mode get best performance, auto mode will - try to use fast mode, if failed will use slow mode""", + help="""tokenizer load mode, can be slow, fast, auto, or deepseek_v32. + slow mode load fast but run slow, good for debug and test. + fast mode get best performance. + auto mode will try to use fast mode, if failed will use slow mode. + deepseek_v32 mode wraps the tokenizer with Python-based DSML chat + template encoding for DeepSeek-V3.2 models (no --chat_template needed).""", ) parser.add_argument( "--load_way", @@ -128,7 +131,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--tool_call_parser", type=str, - choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"], + choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "deepseekv32", "glm47", "kimi_k2"], default=None, help="tool call parser type", ) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index d91bb1d947..ee8a35fd66 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -338,6 +338,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason = None + has_emitted_tool_calls = False from .req_id_generator import convert_sub_id_to_group_id prompt_tokens = 0 @@ -358,7 +359,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if reasoning_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(reasoning_content=reasoning_text), + delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text), finish_reason=None, ) chunk = ChatCompletionStreamResponse( @@ -367,7 +368,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" if request.tool_choice != "none" and request.tools: if index not in parser_dict: @@ -386,8 +387,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if normal_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(content=normal_text), - finish_reason=finish_reason if 
finish_reason else None, + delta=DeltaMessage(role="assistant", content=normal_text), + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -395,11 +396,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) for call_item in calls: + has_emitted_tool_calls = True # transform call_item -> FunctionResponse + ToolCall if finish_reason == "stop": latest_delta_len = 0 @@ -436,7 +438,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choice_data = ChatCompletionStreamResponseChoice( index=0, delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), - finish_reason="tool_calls", + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -444,24 +446,36 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" else: - group_request_id = convert_sub_id_to_group_id(sub_req_id) - delta_message = DeltaMessage(role="assistant", content=delta) - if finish_status.is_finished(): - finish_reason = finish_status.get_finish_reason() - stream_choice = ChatCompletionStreamResponseChoice( - index=0, delta=delta_message, finish_reason=finish_reason - ) + stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None) stream_resp = ChatCompletionStreamResponse( id=group_request_id, created=created_time, model=request.model, choices=[stream_choice], ) - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") - # Additional usage chunk + yield f"data: {stream_resp.model_dump_json(exclude_none=True)}\n\n" + + # Determine final finish_reason: override to "tool_calls" if tool calls were emitted + if has_emitted_tool_calls and finish_reason == "stop": + finish_reason = "tool_calls" + + # Final empty chunk containing only finish_reason (and role) + if finish_reason is not None: + final_choice = ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(), + finish_reason=finish_reason, + ) + final_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[final_choice], + ) + yield f"data: {final_chunk.model_dump_json(exclude_none=True)}\n\n" if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( @@ -476,7 +490,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) @@ -677,7 +691,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) diff --git a/lightllm/server/function_call_parser.py 
b/lightllm/server/function_call_parser.py index 9214715b1d..3a8fddf744 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -29,7 +29,15 @@ logger = logging.getLogger(__name__) -TOOLS_TAG_LIST = ["<|plugin|>", "", "<|python_tag|>", "[TOOL_CALLS]", "<|tool▁calls▁begin|>"] +TOOLS_TAG_LIST = [ + "<|plugin|>", + "", + "<|python_tag|>", + "[TOOL_CALLS]", + "<|tool▁calls▁begin|>", + "<|DSML|function_calls>", +] class ToolCallItem(BaseModel): @@ -1443,6 +1451,272 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text="", calls=calls) +class DeepSeekV32Detector(BaseFormatDetector): + """ + Detector for DeepSeek V3.2 model function call format using DSML + (DeepSeek Markup Language). + + Format Structure: + ``` + <|DSML|function_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">Hangzhou + <|DSML|parameter name="date" string="true">2024-01-16 + + + ``` + + Key Components: + - Function Calls Block: `<|DSML|function_calls>` ... `` + - Individual Invocation: `<|DSML|invoke name="func">` ... `` + - Parameters: `<|DSML|parameter name="key" string="true|false">value` + - string="true": value is plain text (will be JSON-escaped) + - string="false": value is JSON (numbers, booleans, arrays, objects) + - Supports multiple parallel tool calls + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 + """ + + def __init__(self): + super().__init__() + self.dsml_token = "|DSML|" + self.bot_token = f"<{self.dsml_token}function_calls>" + self.eot_token = f"" + self.invoke_start_prefix = f"<{self.dsml_token}invoke" + self.invoke_end_token = f"" + self.param_end_token = f"" + + # Regex for complete invoke extraction + _de = re.escape(self.dsml_token) + self.invoke_regex = re.compile( + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*?)', + re.DOTALL, + ) + # Regex for parameter extraction + self.param_regex = re.compile( + rf'<{_de}parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)', + re.DOTALL, + ) + # Regex for partial invoke (name known, body still streaming) + self.partial_invoke_regex = re.compile( + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*)', + re.DOTALL, + ) + + self._last_arguments = "" + self._accumulated_params: List[tuple] = [] + self._in_function_calls = False # Track if we're inside a function_calls block + + def has_tool_call(self, text: str) -> bool: + return self.bot_token in text + + def _dsml_params_to_json(self, params: List[tuple]) -> str: + """Convert DSML parameter tuples (name, is_str, value) to a JSON arguments string.""" + args = {} + for name, is_str, value in params: + if is_str == "true": + args[name] = value + else: + try: + args[name] = json.loads(value) + except (json.JSONDecodeError, ValueError): + args[name] = value + return json.dumps(args, ensure_ascii=False) + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """One-time parsing for DSML format tool calls.""" + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + + tool_indices = self._get_tool_indices(tools) + calls = [] + + invoke_matches = self.invoke_regex.findall(text) + for func_name, invoke_body in invoke_matches: + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + continue + + param_matches = 
self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) + + calls.append( + ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=args_json, + ) + ) + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: + """Streaming incremental parsing for DSML format tool calls.""" + self._buffer += new_text + current_text = self._buffer + + # Check if we're inside a function_calls block or starting one + has_tool = self.has_tool_call(current_text) or self._in_function_calls + + if not has_tool: + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult() + + self._buffer = "" + for e_token in [self.eot_token, self.invoke_end_token]: + if e_token in new_text: + new_text = new_text.replace(e_token, "") + return StreamingParseResult(normal_text=new_text) + + # Mark that we're inside a function_calls block + if self.has_tool_call(current_text): + self._in_function_calls = True + + # Check if function_calls block has ended + if self.eot_token in current_text: + self._in_function_calls = False + + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls: List[ToolCallItem] = [] + + try: + # Try to find complete invoke blocks first + complete_invoke_match = self.invoke_regex.search(current_text) + if complete_invoke_match: + func_name = complete_invoke_match.group(1) + invoke_body = complete_invoke_match.group(2) + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + param_matches = self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) + + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + + # Send complete arguments (or remaining diff) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = args_json[sent:] + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + try: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": json.loads(args_json), + } + except json.JSONDecodeError: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + + # Remove processed invoke from buffer + invoke_end_pos = current_text.find(self.invoke_end_token, complete_invoke_match.start()) + if invoke_end_pos != -1: + self._buffer = current_text[invoke_end_pos + len(self.invoke_end_token) :] + else: + self._buffer = current_text[complete_invoke_match.end() :] + + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + self._accumulated_params = [] + self.streamed_args_for_tool.append("") + + return StreamingParseResult(normal_text="", calls=calls) + + # Partial invoke: name is known but parameters are still streaming + partial_match = self.partial_invoke_regex.search(current_text) + 
if partial_match: + func_name = partial_match.group(1) + partial_body = partial_match.group(2) + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if not self.current_tool_name_sent: + if func_name in self._tool_indices: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Stream arguments as complete parameters are parsed + param_matches = self.param_regex.findall(partial_body) + if param_matches and len(param_matches) > len(self._accumulated_params): + self._accumulated_params = param_matches + current_args_json = self._dsml_params_to_json(param_matches) + + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = current_args_json[sent:] + + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + try: + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = json.loads(current_args_json) + except json.JSONDecodeError: + pass + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in DeepSeekV32 parse_streaming_increment: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. 
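For reference, a minimal standalone sketch of the DSML-to-JSON argument conversion performed by `_dsml_params_to_json` above. The parameter tuples and tool name below are made up for illustration; they stand in for what the detector's parameter regex would extract from one `<|DSML|invoke name="get_weather">` block:

import json

def dsml_params_to_json(params):
    # Mirrors DeepSeekV32Detector._dsml_params_to_json:
    # string="true" values stay as plain text, string="false" values are parsed as JSON.
    args = {}
    for name, is_str, value in params:
        if is_str == "true":
            args[name] = value
        else:
            try:
                args[name] = json.loads(value)
            except (json.JSONDecodeError, ValueError):
                args[name] = value
    return json.dumps(args, ensure_ascii=False)

# Hypothetical parameters extracted from a single invoke block.
params = [
    ("location", "true", "Hangzhou"),
    ("days", "false", "3"),
    ("units", "false", '{"temperature": "celsius"}'),
]
print(dsml_params_to_json(params))
# -> {"location": "Hangzhou", "days": 3, "units": {"temperature": "celsius"}}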
@@ -1455,6 +1729,7 @@ class FunctionCallParser: ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, + "deepseekv32": DeepSeekV32Detector, "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index c517748984..88b099459b 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -241,7 +241,8 @@ def match_prefix(self, key, update_refs=False): value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) return tree_node, len(value), value else: - self.dec_node_ref_counter(self.root_node) + if update_refs: + self.dec_node_ref_counter(self.root_node) return None, 0, None def _match_prefix_helper( diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index e2ccf290e8..ca3901ebd0 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -3,6 +3,7 @@ from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args @@ -16,7 +17,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_mask_eos_reqs, is_all_greedy, ) = _get_post_sample_tensors(reqs) - eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True) + eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) sampling_params_manager = g_infer_context.req_manager.req_sampling_params_manager @@ -128,12 +129,14 @@ def _get_post_sample_tensors(reqs: List[InferReq]): is_all_greedy = False req_idxes.append(req_obj.req_idx) - req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True) - temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True) - top_ps_cpu = torch.tensor(top_ps, dtype=torch.float, device="cpu", pin_memory=True) - top_ks_cpu = torch.tensor(top_ks, dtype=torch.int32, device="cpu", pin_memory=True) - length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True) - mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True) + req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) + temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) + top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32) + top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32) + length_penalty_param_cpu = g_pin_mem_manager.gen_from_list( + key="length_penalty_param", data=length_penalty_param, dtype=torch.int32 + ) + mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) return ( req_idxes_cpu.cuda(non_blocking=True), diff 
--git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index e0b2bd425e..b5b5148582 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -43,7 +43,20 @@ def get_tokenizer( *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) + model_type = model_cfg.get("model_type", "") """Gets a tokenizer for the given model name via Huggingface.""" + # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with + # a Python-based apply_chat_template that uses encoding_dsv32.py. + if model_type == "deepseek_v32": + from ..models.deepseek3_2.model import DeepSeekV32Tokenizer + + hf_tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs + ) + logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") + return DeepSeekV32Tokenizer(hf_tokenizer) + if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") @@ -75,8 +88,6 @@ def get_tokenizer( "slowdown. Consider using a fast tokenizer instead." ) - model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) - model_type = model_cfg.get("model_type", "") if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor diff --git a/test/chat_template/tool_chat_template_deepseekv32.jinjia b/test/chat_template/tool_chat_template_deepseekv32.jinjia index b6d239dce7..7bb0fc375f 100644 --- a/test/chat_template/tool_chat_template_deepseekv32.jinjia +++ b/test/chat_template/tool_chat_template_deepseekv32.jinjia @@ -1,101 +1,202 @@ -{% if not add_generation_prompt is defined %} - {% set add_generation_prompt = false %} -{% endif %} -{% if not thinking is defined %} - {% set thinking = false %} -{% endif %} -{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false, is_only_sys=false, is_prefix=false) %} -{%- for message in messages %} - {%- if message['role'] == 'system' %} - {%- if ns.is_first_sp %} - {% set ns.system_prompt = ns.system_prompt + message['content'] %} - {% set ns.is_first_sp = false %} - {%- else %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} - {%- endif %} - {% set ns.is_only_sys = true %} - {%- endif %} -{%- endfor %} - -{% if tools is defined and tools is not none %} - {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} - {% for tool in tools %} - {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} - {% endfor %} - {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} -{% endif %} - -{{ bos_token }}{{ ns.system_prompt }} -{%- for message in messages %} - {%- if message['role'] == 'user' %} - {%- set ns.is_tool 
= false -%} - {%- set ns.is_first = false -%} - {%- set ns.is_last_user = true -%} - {{'<|User|>' + message['content']}} - {%- endif %} - {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} - {%- if ns.is_last_user or ns.is_only_sys %} - {{'<|Assistant|>'}} - {%- endif %} - {%- set ns.is_last_user = false -%} - {%- set ns.is_first = false %} - {%- set ns.is_tool = false -%} - {%- for tool in message['tool_calls'] %} - {%- set formatted_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments']|tojson %} - {%- if not ns.is_first %} - {%- if message['content'] is none %} - {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- else %} - {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- set ns.is_first = true -%} - {%- else %} - {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- endfor %} - {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} - {%- endif %} - {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} - {%- if ns.is_last_user %} - {{'<|Assistant|>'}} - {%- if message['prefix'] is defined and message['prefix'] and thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {%- endif %} - {%- if message['prefix'] is defined and message['prefix'] %} - {%- set ns.is_prefix = true -%} - {%- endif %} - {%- set ns.is_last_user = false -%} - {%- if ns.is_tool %} - {{message['content'] + '<|end▁of▁sentence|>'}} - {%- set ns.is_tool = false -%} - {%- else %} - {%- set content = message['content'] -%} - {%- if '' in content %} - {%- set content = content.split('', 1)[1] -%} - {%- endif %} - {{content + '<|end▁of▁sentence|>'}} - {%- endif %} - {%- endif %} - {%- if message['role'] == 'tool' %} - {%- set ns.is_last_user = false -%} - {%- set ns.is_tool = true -%} - {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} - {%- endif %} - {%- if message['role'] != 'system' %} - {% set ns.is_only_sys = false %} - {%- endif %} +{#- ============================================================================ + DeepSeek-V3.2 DSML Chat Template + Converted from encoding_dsv32.py encode_messages function. + Uses DSML (DeepSeek Markup Language) format for tool calls. + ============================================================================ -#} +{%- set bos_token = "<|begin▁of▁sentence|>" -%} +{%- set eos_token = "<|end▁of▁sentence|>" -%} +{%- set thinking_start_token = "" -%} +{%- set thinking_end_token = "" -%} +{%- set dsml_token = "|DSML|" -%} + +{%- set system_msg_template = "{content}" -%} +{%- set user_msg_template = "<|User|>{content}<|Assistant|>" -%} +{%- set assistant_msg_template = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" -%} +{%- set thinking_template = "{reasoning_content}" -%} +{%- set tool_call_template = "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n{dsml_token}invoke>" -%} +{%- set tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n{dsml_token}function_calls>" -%} +{%- set tool_output_template = "\n{content}" -%} + +{%- set TOOLS_SYSTEM_TEMPLATE -%} +## Tools +You have access to a set of tools you can use to answer the user's question. 
+You can invoke functions by writing a "<{{ dsml_token }}function_calls>" block like the following as part of your reply to the user: + +<{{ dsml_token }}function_calls> +<{{ dsml_token }}invoke name="$FUNCTION_NAME"> +<{{ dsml_token }}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{{ dsml_token }}parameter> +... +{{ dsml_token }}invoke> +<{{ dsml_token }}invoke name="$FUNCTION_NAME2"> +... +{{ dsml_token }}invoke> +{{ dsml_token }}function_calls> + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). + +Here are the functions available in JSONSchema format: +{tool_schemas} +{%- endset -%} + +{%- if thinking_mode is not defined -%} + {%- set thinking_mode = "thinking" -%} +{%- endif -%} +{%- if drop_thinking is not defined -%} + {%- set drop_thinking = true -%} +{%- endif -%} +{%- if add_default_bos_token is not defined -%} + {%- set add_default_bos_token = true -%} +{%- endif -%} + +{#- Macro: encode_arguments_to_dsml -#} +{%- macro encode_arguments_to_dsml(arguments) -%} + {%- set ns = namespace(P_dsml_strs=[]) -%} + {%- if arguments is mapping -%} + {%- for k, v in arguments.items() -%} + {%- if v is string -%} + {%- set is_str = "true" -%} + {%- set value = v -%} + {%- else -%} + {%- set is_str = "false" -%} + {%- set value = v | tojson -%} + {%- endif -%} + {%- set p_dsml_str = "<" ~ dsml_token ~ "parameter name=\"" ~ k ~ "\" string=\"" ~ is_str ~ "\">" ~ value ~ dsml_token ~ "parameter>" -%} + {%- set ns.P_dsml_strs = ns.P_dsml_strs + [p_dsml_str] -%} + {%- endfor -%} + {%- endif -%} + {{- ns.P_dsml_strs | join("\n") -}} +{%- endmacro -%} + +{#- Macro: render_tools -#} +{%- macro render_tools(tools) -%} + {%- set ns = namespace(tools_json=[]) -%} + {%- for tool in tools -%} + {%- if tool.function is defined -%} + {%- set ns.tools_json = ns.tools_json + [tool.function | tojson] -%} + {%- else -%} + {%- set ns.tools_json = ns.tools_json + [tool | tojson] -%} + {%- endif -%} + {%- endfor -%} + {{- TOOLS_SYSTEM_TEMPLATE | replace("{tool_schemas}", ns.tools_json | join("\n")) }} +{% endmacro -%} + +{#- Macro: find_last_user_index -#} +{%- macro find_last_user_index(messages) -%} + {%- set ns = namespace(last_user_index=-1) -%} + {%- for msg in messages -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- if role in ['user', 'developer'] -%} + {%- set ns.last_user_index = loop.index0 -%} + {%- endif -%} + {%- endfor -%} + {{- ns.last_user_index -}} +{%- endmacro -%} + +{#- Macro: render_tool_calls_content -#} +{%- macro render_tool_calls_content(tool_calls) -%} + {%- set ns = namespace(formatted_calls=[]) -%} + {%- for tool_call in tool_calls -%} + {%- if tool_call.function is defined -%} + {%- set name = tool_call.function.name -%} + {%- set arguments = tool_call.function.arguments -%} + {%- else -%} + {%- set name = tool_call.name -%} + {%- set arguments = tool_call.arguments -%} + {%- endif -%} + {%- if arguments is string -%} + {%- set arguments = arguments | fromjson -%} + {%- endif -%} + {%- set formatted_call = "<" ~ dsml_token ~ "invoke name=\"" ~ name ~ "\">\n" ~ encode_arguments_to_dsml(arguments) ~ "\n" ~ dsml_token ~ "invoke>" -%} + {%- set ns.formatted_calls = ns.formatted_calls + [formatted_call] -%} + {%- endfor -%} + {{- "<" ~ dsml_token ~ "function_calls>\n" ~ ns.formatted_calls | join("\n") ~ "\n" ~ 
dsml_token ~ "function_calls>" -}} +{%- endmacro -%} + +{#- Macro: render_message -#} +{%- macro render_message(index, messages, thinking_mode) -%} + {%- set msg = messages[index] -%} + {%- set last_user_idx = find_last_user_index(messages) | int -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- set content = msg.content if msg.content is defined else (msg.get('content', '') or '') -%} + {%- set msg_tools = msg.tools if msg.tools is defined else msg.get('tools', []) -%} + {%- set tool_calls = msg.tool_calls if msg.tool_calls is defined else msg.get('tool_calls', []) -%} + {%- set reasoning_content = msg.reasoning_content if msg.reasoning_content is defined else (msg.get('reasoning_content', '') or '') -%} + + {%- if role == 'system' -%} + {{- content or '' -}} + {%- if msg_tools -%} + {{- "\n\n" ~ render_tools(msg_tools) -}} + {%- endif -%} + + {%- elif role == 'user' -%} + {{- "<|User|>" ~ content ~ "<|Assistant|>" -}} + {%- if index == last_user_idx and thinking_mode == "thinking" -%} + {{- thinking_start_token -}} + {%- else -%} + {{- thinking_end_token -}} + {%- endif -%} + + {%- elif role == 'tool' -%} + {%- set ns = namespace(prev_assistant_idx=-1) -%} + {%- for i in range(index - 1, -1, -1) -%} + {%- set check_role = messages[i].role if messages[i].role is defined else messages[i].get('role') -%} + {%- if check_role != 'tool' and ns.prev_assistant_idx == -1 -%} + {%- set ns.prev_assistant_idx = i -%} + {%- endif -%} + {%- endfor -%} + {%- set tool_call_order = index - ns.prev_assistant_idx -%} + {%- set assistant_msg = messages[ns.prev_assistant_idx] -%} + {%- set assistant_tool_calls = assistant_msg.tool_calls if assistant_msg.tool_calls is defined else assistant_msg.get('tool_calls', []) -%} + {%- if tool_call_order == 1 -%} + {{- "\n\n" -}} + {%- endif -%} + {{- "\n" ~ content -}} + {%- if tool_call_order == (assistant_tool_calls | length) -%} + {{- "\n" -}} + {%- if index >= last_user_idx and thinking_mode == "thinking" -%} + {{- "\n\n" ~ thinking_start_token -}} + {%- else -%} + {{- "\n\n" ~ thinking_end_token -}} + {%- endif -%} + {%- endif -%} + + {%- elif role == 'assistant' -%} + {%- set ns = namespace(thinking_part="", tool_calls_content="") -%} + {%- if tool_calls -%} + {%- set ns.tool_calls_content = "\n\n" ~ render_tool_calls_content(tool_calls) -%} + {%- endif -%} + {%- set summary_content = content or "" -%} + {%- if thinking_mode == "thinking" and index > last_user_idx -%} + {%- set ns.thinking_part = reasoning_content ~ thinking_end_token -%} + {%- endif -%} + {{- ns.thinking_part ~ summary_content ~ ns.tool_calls_content ~ "<|end▁of▁sentence|>" -}} + {%- endif -%} +{%- endmacro -%} + +{#- Main template body -#} +{%- set full_messages = messages -%} + +{#- Handle tools in top-level (OpenAI format) -#} +{%- if tools is defined and tools is not none -%} + {%- set ns_sys = namespace(has_system=false, sys_idx=-1) -%} + {%- for msg in full_messages -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- if role == 'system' and not ns_sys.has_system -%} + {%- set ns_sys.has_system = true -%} + {%- set ns_sys.sys_idx = loop.index0 -%} + {%- endif -%} + {%- endfor -%} +{%- endif -%} + +{%- if add_default_bos_token -%} + {{- bos_token -}} +{%- endif -%} + +{#- If tools defined at top level but no system message has them, prepend tools info -#} +{%- if tools is defined and tools is not none -%} + {{- render_tools(tools) -}} +{%- endif -%} + +{%- for msg in full_messages -%} + {{- render_message(loop.index0, 
full_messages, thinking_mode) -}} {%- endfor -%} -{% if add_generation_prompt and not ns.is_tool%} - {% if ns.is_last_user or ns.is_only_sys or not ns.is_prefix %} - {{'<|Assistant|>'}} - {%- if not thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {% endif %} -{% endif %} diff --git a/test/test_api/test_gsmk.py b/test/test_api/test_gsmk.py new file mode 100644 index 0000000000..2d9ead65b8 --- /dev/null +++ b/test/test_api/test_gsmk.py @@ -0,0 +1,265 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + +SYSTEM_PROMPT_TARGET_LEN = 18192 + + +def generate_system_prompt(): + """Generate a system prompt of approximately 8192 characters.""" + base = ( + "You are a highly capable math assistant. Your task is to solve grade school math problems step by step. " + "Show your reasoning clearly and provide the final numerical answer. " + "Break down each problem into smaller steps and verify your calculations. " + "Always end your answer with the format: #### . " + ) + # Repeat base text to reach target length + repeats = SYSTEM_PROMPT_TARGET_LEN // len(base) + 1 + prompt = (base * repeats)[:SYSTEM_PROMPT_TARGET_LEN] + return prompt + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + yield json.loads(line) + + +def dump_state_text(filename: str, states: list, mode: str = "w"): + """Dump program state in a text file.""" + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + fout.write(f"==== {i} ====\n{s}\n") + else: + fout.write(f"==== {i} ====\n{str(s)}\n") + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as file, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + size = file.write(chunk) + bar.update(size) + + return filename + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + "repetition_penalty": 1.0, + "top_p": 1.0, + "top_k": 1, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. 
Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + "Invalid API response format. 'generated_text' should be a non-empty list, " + f"got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + # First try to find the answer after "####" marker (GSM8K format) + match = re.search(r"####\s*(-?\d+)", answer_str) + if match: + try: + return ast.literal_eval(match.group(1)) + except SyntaxError: + pass + # Fallback: find all numbers and take the last one + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=256) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument( + "--system-prompt", action="store_true", help="Prepend an 8192-character system prompt to each request" + ) + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + system_prefix = "" + if args.system_prompt: + system_prefix = generate_system_prompt() + "\n\n" + print(f"System prompt enabled: {len(system_prefix)} characters") + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. 
" + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=system_prefix + few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args)