diff --git a/.gitignore b/.gitignore index 63408699f4..3fb49db8b1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 89d90f3c12..219343e2c6 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F import triton -from typing import final, List +from typing import final, List, Optional from tqdm import tqdm from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights @@ -33,6 +33,10 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .attention import get_prefill_att_backend_class, get_decode_att_backend_class from .attention import BaseAttBackend @@ -91,6 +95,7 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) self.is_mtp_mode = self.args.mtp_mode in [ "vanilla_with_att", "eagle_with_att", @@ -104,19 +109,21 @@ def __init__(self, kvargs): self._verify_params() self._init_quant() - self._init_weights() - self._init_req_manager() - self._init_mem_manager() + with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup): + self._init_weights() + with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): + self._init_req_manager() + self._init_mem_manager() + self._init_kv_move_buffer() + # 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state # 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值 self.req_manager.mem_manager = self.mem_manager - - self._init_kv_move_buffer() self._check_mem_size() self._init_infer_layer() self._init_some_value() self._init_custom() - self._load_hf_weights() + self.load_weights(self.weight_dict) # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() @@ -181,17 +188,15 @@ def _init_weights(self, start_layer_index=0): ] return - def _load_hf_weights(self): + def load_weights(self, weight_dict: dict): + assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None" load_hf_weights( self.data_type, - weight_dir=self.weight_dir_, + self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, + weight_dict=weight_dict, ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 @@ -999,6 +1004,7 @@ def _check_max_len_infer(self): ) logger.error(exception_str) raise Exception(exception_str) + torch.cuda.empty_cache() return def autotune_layers(self): @@ -1133,6 +1139,9 @@ def _init_padded_req(self): del b_seq_len del b_ready_cache_len del model_output + del b_mtp_index + del b_prefill_start_loc + del b_q_seq_len torch.cuda.empty_cache() return @@ -1153,3 +1162,72 @@ def _gen_special_model_input(self, token_num: int): special_model_input["mtp_draft_input_hiddens"] = None return special_model_input + + def release_memory_occupation(self, tags: Optional[List[MemoryTag]]): + torch.cuda.synchronize() + if tags is None: + self.release_all() + return + if MemoryTag.WEIGHT in tags: + self.release_weight() + if MemoryTag.KV_CACHE in tags: + self.release_kv_cache() + if MemoryTag.GRAPH in tags: + self.release_graph() + return + + def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): + if tags is None: + self.resume_all() + return + if MemoryTag.WEIGHT in tags: + self.resume_weight() + if MemoryTag.KV_CACHE in tags: + self.resume_kv_cache() + if MemoryTag.GRAPH in tags: + self.resume_graph() + return + + def release_weight(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + torch.cuda.empty_cache() + gc.collect() + + def release_kv_cache(self): + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + torch.cuda.empty_cache() + gc.collect() + + def release_graph(self): + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() + gc.collect() + + def release_all(self): + self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) + torch.cuda.empty_cache() + gc.collect() + + def resume_weight(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + + def resume_kv_cache(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + + def resume_graph(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) + + def resume_all(self): + torch.cuda.empty_cache() + gc.collect() + self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) + self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) + self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 2a099ffc75..348d9216b7 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -8,6 +8,10 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.torch_memory_saver_utils import ( + TorchMemorySaverWrapper, + MemoryTag, +) from .infer_struct import InferStateInfo @@ -26,6 +30,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] @@ -96,7 +101,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): delattr(infer_state, param_name) with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) graph_obj.replay() @@ -134,7 +139,7 @@ def _capture_decode_overlap( with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) self.graph[batch_size] = ( graph_obj, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..6ca48299f0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,6 +33,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name @@ -50,6 +51,7 @@ def __init__( self.enable_ep_moe = get_env_start_args().enable_ep_moe self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts + self.moe_layer_index = moe_layer_index self._init_config(network_config) self._init_redundancy_expert_params() self._init_parallel_params() @@ -130,6 +132,7 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +148,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + moe_layer_index=self.moe_layer_index, + microbatch_index=microbatch_index, ) def low_latency_dispatch( @@ -295,6 +300,7 @@ def _create_weight(self): device_id=self.device_id_, num_experts=self.local_n_routed_experts, ) + self.w1, self.w3 = w13_param_list self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0]) self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) @@ -307,7 +313,8 @@ def _get_expert_weight_list(self, weight_pack: WeightPack): return weight_list def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): - + # for merged weights + self._load_merge_weight(weights) # Load each expert with TP slicing for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: @@ -332,6 +339,7 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" + row_slice_func = self.row_slicer._slice_weight col_slice_func = self.col_slicer._slice_weight if w1_weight in weights: @@ -341,6 +349,19 @@ def _load_expert( if w2_weight in weights: self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + def _load_merge_weight(self, weights: Dict[str, torch.Tensor]): + w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}" + w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}" + w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}" + row_slice_func = self.row_slicer._slice_weight + col_slice_func = self.col_slicer._slice_weight + if w1_merge_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1) + if w2_merge_weight in weights: + self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2) + if w3_merge_weight in weights: + self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3) + def _load_expert_scale( self, expert_idx: int, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b4..4ca1605be4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -8,6 +8,7 @@ from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -46,6 +47,7 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + moe_layer_index: int = 0, ) -> None: network_config["norm_topk_prob"] = None super().__init__( @@ -62,6 +64,7 @@ def __init__( num_fused_shared_experts=num_fused_shared_experts, layer_num=layer_num, network_config=network_config, + moe_layer_index=moe_layer_index, ) self.hidden_size = network_config["hidden_size"] @@ -144,10 +147,15 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._router(router_logits, top_k) + # Rollout router replay + if _routing_mgr.g_routing_capture_manager is not None: + _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac185..1c93cb13dc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index d6e923a115..90b525d275 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -3,6 +3,7 @@ from lightllm.common.quantization.no_quant import WeightPack from lightllm.common.quantization.quantize_method import QuantizationMethod from .base_impl import FuseMoeBaseImpl +from lightllm.common.basemodel import routing_manager as _routing_mgr class FuseMoeTriton(FuseMoeBaseImpl): @@ -125,6 +126,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + moe_layer_index: Optional[int] = None, + microbatch_index: int = 0, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -137,6 +140,10 @@ def __call__( num_expert_group=num_expert_group, scoring_func=scoring_func, ) + + if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None: + _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index) + output = self._fused_experts( input_tensor=input_tensor, w13=w13, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a866..15f050c14a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten # 默认weight 的shape是 outxin,这也是目前最通用的约定。 -# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +# 这里约定row-wise沿着倒数第二维切分,col-wise沿着第一维切分。 class RowSliceMixin(SliceMixinTpl): def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[0]) - return weight[start:end, :] + weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-2]) + return weight[..., start:end, :] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: assert ( @@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[0]) - return weight_scale[start:end] + weight_scale.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-2]) + return weight_scale[..., start:end, :] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[0]) - return weight_zero_point[start:end] + weight_zero_point.shape[-2] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-2]) + return weight_zero_point[..., start:end, :] class ColSliceMixin(SliceMixinTpl): @@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: assert ( - weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight.shape[1]) - return weight[:, start:end] + weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[-1]) + return weight[..., start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ * self.repeat_times_ @@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_scale.shape[1]) - return weight_scale[:, start:end] + ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[-1]) + return weight_scale[..., start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" - start, end = self._get_slice_start_end(weight_zero_point.shape[1]) - return weight_zero_point[:, start:end] + weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[-1]) + return weight_zero_point[..., start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 0000000000..77b611130f --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,191 @@ +import atexit +import torch +import numpy as np +from multiprocessing import shared_memory +from typing import Optional +from lightllm.utils.log_utils import init_logger +from lightllm.utils.dist_utils import get_current_rank_in_dp +from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray +from lightllm.utils.envs_utils import get_unique_server_name + +logger = init_logger(__name__) + + +def routing_dtype_id_to_np(dtype_id: int): + if dtype_id == 1: + return np.uint8 + elif dtype_id == 2: + return np.int16 + return np.int32 + + +def get_routing_config_shm() -> SharedArray: + service_name = get_unique_server_name() + return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32) + + +class RoutingCaptureManager: + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, + ): + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.kv_cache_size = kv_cache_size + + self.dtype = torch.uint8 if num_experts <= 255 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.uint8 else 2 + + # Shape: (kv_cache_size, num_moe_layers, topk) — on CPU to save GPU memory. + # Written after forward() via flush_to_routing_buffer(), read on request finish. + routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.routing_buffer = torch.zeros( + (kv_cache_size, num_moe_layers, topk), + dtype=self.dtype, + device="cpu", + ) + + # Capture buffers: simple contiguous tensors written to during forward(). + capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes + self._capture_buffer = [ + torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2) + ] + + dtype_name = "uint8" if self.dtype == torch.uint8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, " + f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}" + ) + + @property + def np_dtype(self): + return np.uint8 if self.dtype == torch.uint8 else np.int16 + + @property + def dtype_id(self) -> int: + return 1 if self.dtype == torch.uint8 else 2 + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + num_tokens = topk_ids.shape[0] + self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype) + + def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None: + buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk) + self.routing_buffer[mem_indexes[:num_tokens].cpu(), :, :] = buf.cpu() + + def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray: + cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes + return self.routing_buffer[cpu_indexes, :, :].numpy() + + +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + kv_cache_size: int, + max_capture_tokens: int, +) -> None: + global g_routing_capture_manager + assert g_routing_capture_manager is None, "RoutingCaptureManager already exists" + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=kv_cache_size, + max_capture_tokens=max_capture_tokens, + ) + + +def cleanup_routing_shm_pool() -> None: + """Unlink all pre-allocated routing SHM segments. Called at server shutdown.""" + try: + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + except Exception: + return + + service_name = get_unique_server_name() + + for i in range(args.running_max_req_size): + name = f"{service_name}_shm_routing_{i}" + try: + shm = shared_memory.SharedMemory(name=name) + shm.close() + shm.unlink() + except Exception: + pass + + config_name = f"{service_name}_routing_config" + try: + shm = shared_memory.SharedMemory(name=config_name) + shm.close() + shm.unlink() + except Exception: + pass + + +def init_routing_capture(model, num_moe_layers: int) -> None: + dp_rank = get_current_rank_in_dp() + logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}") + if dp_rank != 0: + logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}") + return + + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled." + ) + return + + num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0)) + topk = model.config.get("num_experts_per_tok", 0) + assert num_experts > 0 and topk > 0 + + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + + # Capture buffer must fit the max tokens in any single forward call. + # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size. + batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192 + max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + kv_cache_size=model.mem_manager.size + 1, + max_capture_tokens=max_capture_tokens, + ) + + mgr = g_routing_capture_manager + dtype_id = mgr.dtype_id + + max_req_total_len = args.max_req_total_len + + # Write config to cross-process SHM + shm = get_routing_config_shm() + shm.arr[0] = num_moe_layers + shm.arr[1] = topk + shm.arr[2] = dtype_id + shm.arr[3] = max_req_total_len + logger.info( + f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, " + f"dtype_id={dtype_id}, max_tokens={max_req_total_len}" + ) + atexit.register(cleanup_routing_shm_pool) diff --git a/lightllm/common/basemodel/triton_kernel/post_process/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py new file mode 100644 index 0000000000..353affd8ed --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py @@ -0,0 +1,36 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_apply_invalid_token( + Logits, + invalid_token_ids, + cu_invalid_token_num, + stride_logit_b, +): + cur_batch = tl.program_id(0) + start_index = tl.load(cu_invalid_token_num + cur_batch) + end_index = tl.load(cu_invalid_token_num + cur_batch + 1) + for i in range(start_index, end_index): + cur_invalid_token_id = tl.load(invalid_token_ids + i) + cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id + tl.store(cur_logit_ptr, float("-inf")) + return + + +def apply_invalid_token_ids( + Logits: torch.Tensor, + invalid_token_ids: torch.Tensor, + cu_invalid_token_num: torch.Tensor, +): + batch_size = Logits.shape[0] + grid = (batch_size,) + _fwd_kernel_apply_invalid_token[grid]( + Logits=Logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + stride_logit_b=Logits.stride(0), + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..c75c871c72 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..14026090e6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json new file mode 100644 index 0000000000..939c939523 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 2, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..13ba4ba8e5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ee316f610b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e027701092 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ddda23d257 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "67584": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..560ca6c09d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0713de7996 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..e950ff0954 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": true, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": true, + "num_stages": 4, + "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7f479b8382 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 2 + }, + "8448": { + "BLOCK_SIZE": 256, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b3051c6584 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "8448": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..fdb3212216 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a94e669353 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..441421fd5d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "67584": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..864d1d3f18 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 4096, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bcf56e01f7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 128, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..ba1dc8a75d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "3072": { + "BLOCK_SIZE": 2048, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..6f109e1c6e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 32768, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..198a196dfb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..537c7a90eb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..9a6dcb6fbf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4096": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..df501847ec --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "5120": { + "BLOCK_SIZE": 1024, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572ff..60c12e4c07 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -106,14 +106,10 @@ def __init__( self.configs_gen_func = configs_gen_func self.kernel_name = kernel_name - self.cache_dir = os.path.join( - Path(__file__).parent, - "autotune_kernel_configs", - get_triton_version(), - get_current_device_name(), - self.kernel_name, - ) - os.makedirs(self.cache_dir, exist_ok=True) + # cache_dir 依赖 get_current_device_name(),后者要求 torch.cuda.is_available()。 + # 这里 lazy 化,避免 CPU-only 的进程(例如 Ray driver / verl rollout replica + # 入口)在 import 时就触发 TypeError。 + self._cache_dir: Optional[str] = None self.fn = fn self.static_key_func = static_key_func self.run_key_func = run_key_func @@ -209,6 +205,25 @@ def __call__(self, *args, **kwargs): return self.fn(*args, **kwargs) + @property + def cache_dir(self) -> str: + if self._cache_dir is None: + device_name = get_current_device_name() + if device_name is None: + raise RuntimeError( + f"Autotuner for kernel {self.kernel_name} requires a visible CUDA/MUSA device " + f"to resolve its cache directory, but torch.cuda.is_available() is False." + ) + self._cache_dir = os.path.join( + Path(__file__).parent, + "autotune_kernel_configs", + get_triton_version(), + device_name, + self.kernel_name, + ) + os.makedirs(self._cache_dir, exist_ok=True) + return self._cache_dir + def _try_load_cache(self, static_key): if static_key in self.cached_configs: return False diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index fa2dee444f..88c4b1e8ee 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -231,6 +231,7 @@ def _moe_ffn_tp( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -258,6 +259,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) if self.n_shared_experts is not None: diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..bd72035072 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -242,6 +242,9 @@ def _init_moe(self): # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + first_moe = self.network_config_["first_k_dense_replace"] + freq = self.network_config_.get("moe_layer_freq", 1) + moe_layer_index = (self.layer_num_ - first_moe) // freq self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -256,6 +259,7 @@ def _init_moe(self): num_fused_shared_experts=self.num_fused_shared_experts, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_ffn(self): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..79bd327068 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -49,6 +50,9 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 4c457fd993..bb9e6140bf 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -52,6 +52,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) hidden_states = hidden_states.view(num_tokens, hidden_dim) return self._tpsp_reduce(input=hidden_states, infer_state=infer_state) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 7c8c30940e..7278c62fec 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -55,6 +55,7 @@ def _init_moe(self): num_fused_shared_experts=0, layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) def _init_weight_names(self): diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 9e9561eb24..cff748933d 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class @@ -21,6 +22,12 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) + def _init_custom(self): + super()._init_custom() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) + def _init_att_backend(self): self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( model=self diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py deleted file mode 100644 index b0e27ac1de..0000000000 --- a/lightllm/models/mixtral/layer_infer/_custom_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -import functools -import json -import os -from typing import Any, Dict, Optional, Tuple - -import torch -import triton -import triton.language as tl -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -# Pytorch version -# Triton version in progress -def topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output, - topk=2, -): - scores = torch.softmax(gating_output, dim=-1) - topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False) - return topk_weights, topk_ids - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - alloc_tensor_func=torch.empty, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device) - topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device) - topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 0cf651598a..d90c631547 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -21,25 +18,14 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - ffn2_out = fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, - ) - return self._tpsp_reduce(input=ffn2_out, infer_state=infer_state) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index 51c62fd4cb..d93cb5fb58 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -57,4 +57,5 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=self.layer_num_, ) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e87..35bf38de58 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer @@ -45,6 +46,9 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + if self.args.enable_return_routed_experts: + num_moe_layers = len(self.trans_layers_weight) + init_routing_capture(self, num_moe_layers) return def _init_mem_manager(self): diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 237c4ad897..c94135573b 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -12,6 +12,7 @@ from .vision_process import smart_resize from lightllm.models.qwen2.model import Qwen2TpPartModel import os +from typing import Union, List # Warp of the origal tokenizer class QWen2VLTokenizer(BaseMultiModalTokenizer): @@ -52,9 +53,13 @@ def get_image_token_length(self, img: ImageItem): def get_audio_token_length(self, audio: AudioItem): raise NotImplementedError - def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): - - origin_ids = self.tokenizer.encode(prompt) + def encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams = None, **kwargs): + if isinstance(prompt, str): + origin_ids = self.tokenizer.encode(prompt) + elif isinstance(prompt, list): + origin_ids = prompt + else: + raise ValueError(f"Unsupported prompt type: {type(prompt)}") # -> origin_ids = [token for token in origin_ids if token != self.image_token_id] diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py index fe4b1883bd..44425e7e10 100644 --- a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py @@ -12,7 +12,6 @@ def load_hf_weights(self, weights): def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_size: int): layer_prefix = f"model.layers.{layer_num}." keys = list(weights.keys()) - num_experts = 0 for k in keys: if not k.startswith(layer_prefix): @@ -20,21 +19,8 @@ def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_s if "mlp.experts.gate_up_proj" in k: fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size] - num_experts = fused_weight.shape[0] - prefix = k.rsplit(".gate_up_proj", 1)[0] gate_weight = fused_weight[:, :moe_intermediate_size, :] up_weight = fused_weight[:, moe_intermediate_size:, :] - - for expert_idx in range(num_experts): - weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] - weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] - - elif "mlp.experts.down_proj" in k: - down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] - num_experts = down_weight.shape[0] - - prefix = k.rsplit(".down_proj", 1)[0] - - for expert_idx in range(num_experts): - weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx] + weights[f"{prefix}.gate_proj"] = gate_weight + weights[f"{prefix}.up_proj"] = up_weight diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 54e4373652..744ddc9d4f 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -85,6 +85,7 @@ def _moe_ffn_tp( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=infer_state.microbatch_index, ) return hidden_states.view(num_tokens, hidden_dim) @@ -104,6 +105,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=infer_state.microbatch_index, ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index e525cb2d20..5358229949 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -52,6 +52,11 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) + mlp_only = set(self.network_config_.get("mlp_only_layers", [])) + step = self.network_config_.get("decoder_sparse_step", 1) + moe_layer_index = sum( + 1 for i in range(self.layer_num_) if self.n_routed_experts > 0 and i not in mlp_only and (i + 1) % step == 0 + ) self.experts = FusedMoeWeight( gate_proj_name="gate_proj", down_proj_name="down_proj", @@ -65,6 +70,7 @@ def _init_moe(self): quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), layer_num=self.layer_num_, network_config=self.network_config_, + moe_layer_index=moe_layer_index, ) def _init_qkv(self): diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index b71d7f4878..506f69e3ff 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -1,9 +1,14 @@ import torch from typing import final from lightllm.models.registry import ModelRegistry -from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer -from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import ( + Qwen3MOETransformerLayerInfer, +) +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import ( + Qwen3MOETransformerLayerWeight, +) from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.routing_manager import init_routing_capture from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -28,3 +33,6 @@ def _init_custom(self): # Only initialize DeepEP group for MoE models with num_experts if "num_experts" in self.config and self.config["num_experts"] > 0: dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + if self.args.enable_return_routed_experts: + num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe) + init_routing_capture(self, num_moe_layers) diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 9f83832a7e..b4be10d0d0 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -26,6 +26,7 @@ def _pre_init(self, kvargs: dict): return def _init_custom(self): + super()._init_custom() self._cos_cached = self.main_model._cos_cached self._sin_cached = self.main_model._sin_cached return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 70b638715e..871001d0cb 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -1,8 +1,7 @@ import argparse -def make_argument_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() +def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument( "--run_mode", @@ -261,6 +260,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--nccl_port", type=int, default=None, help="the nccl_port to build a distributed environment for PyTorch" ) + parser.add_argument( + "--lightllm_instance_id", + type=int, + default=0, + help="Instance ID (0~7) for multi-instance port isolation. Each ID maps to a dedicated port range.", + ) parser.add_argument( "--use_config_server_to_init_nccl", action="store_true", @@ -747,6 +752,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used.""" ) + parser.add_argument( + "--enable_torch_memory_saver", + action="store_true", + help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""", + ) + parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""") parser.add_argument( "--disk_cache_dir", type=str, @@ -820,4 +831,10 @@ def make_argument_parser() -> argparse.ArgumentParser: If the op is not implemented for the platform and the hardware support triton, it will use triton implementation.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 4e2c723558..13904e44bd 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -34,7 +34,7 @@ import uuid from PIL import Image import multiprocessing as mp -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Union from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -50,6 +50,7 @@ from lightllm.utils.error_utils import ServerBusyError from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq from dataclasses import dataclass from .api_openai import chat_completions_impl, completions_impl @@ -61,6 +62,15 @@ ModelCard, ModelListResponse, ) +from .io_struct import ( + AbortReq, + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, + GeneralModelToHttpRpcRsp, +) from .build_prompt import build_prompt, init_tokenizer logger = init_logger(__name__) @@ -188,6 +198,22 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} +@app.get("/get_server_info") +@app.post("/get_server_info") +def get_server_info(): + # 将 StartArgs 转换为字典格式 + from dataclasses import asdict + + server_info: dict[str, Any] = asdict(g_objs.args) + return {**server_info} + + +@app.get("/get_weight_version") +@app.post("/get_weight_version") +def get_weight_version(): + return {"weight_version": g_objs.args.weight_version} + + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") @@ -389,6 +415,83 @@ async def metrics() -> Response: return response +@app.post("/abort_request") +async def abort_request(request: AbortReq, raw_request: Request): + """Abort a request.""" + try: + await g_objs.httpserver_manager.abort_request(request) + return Response(status_code=200) + except Exception as e: + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +async def handle_request_common(request_obj, handler): + try: + ret: GeneralModelToHttpRpcRsp = await handler(request_obj) + if ret.success: + return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200) + else: + return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg) + except Exception as e: + logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") + + +@app.post("/init_weights_update_group") +async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request): + """Init weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group) + + +@app.post("/destroy_weights_update_group") +async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request): + """Destroy weights update group.""" + return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group) + + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed) + + +@app.post("/update_weights_from_tensor") +async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request): + """Update model parameter from distributed online.""" + return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor) + + +@app.post("/flush_cache") +@app.get("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache) + + +@app.post("/pause_generation") +async def pause_generation(): + await g_objs.httpserver_manager.pause_generation() + return Response(content="Generation paused successfully.", status_code=200) + + +@app.post("/continue_generation") +async def continue_generation(): + await g_objs.httpserver_manager.continue_generation() + return Response(content="Generation continued successfully.", status_code=200) + + +@app.get("/release_memory_occupation") +@app.post("/release_memory_occupation") +async def release_memory_occupation(request: ReleaseMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation) + + +@app.get("/resume_memory_occupation") +@app.post("/resume_memory_occupation") +async def resume_memory_occupation(request: ResumeMemoryReq): + return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation) + + @app.websocket("/pd_register") async def register_and_keep_alive(websocket: WebSocket): await websocket.accept() diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 39a5808aab..28d57ccdc4 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] return_details = sample_params_dict.pop("return_details", False) + return_routed_experts = sample_params_dict.pop( + "return_routed_experts", httpserver_manager.args.enable_return_routed_experts + ) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() @@ -53,6 +56,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +82,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +108,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if return_routed_experts and routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) @@ -112,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] _ = sample_params_dict.pop("return_details", False) + _ = sample_params_dict.pop("return_routed_experts", None) sampling_params = SamplingParams() sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) sampling_params.verify() diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 6e04d5d47e..5306ecb698 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,11 +1,22 @@ import torch -from .api_cli import make_argument_parser +from .api_cli import add_cli_args +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess - parser = make_argument_parser() - args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start + +def launch_server(args: StartArgs): + from .api_start import pd_master_start, normal_or_p_d_start, config_server_start, visual_only_start + + try: + # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn") + except RuntimeError as e: + logger.warning(f"Failed to set start method: {e}") + except Exception as e: + logger.error(f"Failed to set start method: {e}") + raise e if args.run_mode == "pd_master": pd_master_start(args) @@ -15,3 +26,13 @@ visual_only_start(args) else: normal_or_p_d_start(args) + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + add_cli_args(parser) + args = parser.parse_args() + + launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6f25e42c88..fe793829e5 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import sys import time @@ -17,6 +18,7 @@ from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size +from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.utils.config_utils import has_audio_module, has_vision_module, is_linear_att_mixed_model logger = init_logger(__name__) @@ -53,9 +55,31 @@ def signal_handler(sig, frame): process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) + elif sig == signal.SIGHUP: + logger.info("Received SIGHUP (terminal closed), shutting down gracefully...") + if http_server_process and http_server_process.poll() is None: + http_server_process.send_signal(signal.SIGTERM) + + start_time = time.time() + while (time.time() - start_time) < 60: + if not is_process_active(http_server_process.pid): + logger.info("httpserver exit") + break + time.sleep(1) + + if time.time() - start_time < 60: + logger.info("HTTP server has exited gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") if http_server_process: @@ -63,12 +87,13 @@ def signal_handler(sig, frame): return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs +def _set_envs_and_config(args: StartArgs): + mp.set_start_method("spawn", force=True) - args: StartArgs = args - set_unique_server_name(args) +def _launch_subprocesses(args: StartArgs): + + _set_envs_and_config(args) if args.enable_mps: from lightllm.utils.device_utils import enable_mps @@ -133,12 +158,6 @@ def normal_or_p_d_start(args): check_recommended_shm_size(args) assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] - # 确保单机上多实列不冲突 - if args.zmq_mode == "ipc:///tmp/": - zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" - args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 - args.zmq_mode = zmq_mode - logger.info(f"zmq mode head: {args.zmq_mode}") logger.info(f"use tgi api: {args.use_tgi_api}") @@ -184,12 +203,16 @@ def normal_or_p_d_start(args): # mtp params check if args.mtp_mode is not None: - assert args.mtp_draft_model_dir is not None + if args.mtp_draft_model_dir is None: + args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step assert args.mtp_step > 0 else: assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # automatically set visual_dp based on visual_tp and tp + if args.visual_tp < args.tp and args.tp % args.visual_tp == 0: + args.visual_dp = args.tp // args.visual_tp if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) @@ -324,6 +347,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp, + instance_id=args.lightllm_instance_id, used_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") @@ -358,6 +382,16 @@ def normal_or_p_d_start(args): args.nccl_port = nccl_port if args.pd_decode_rpyc_port is None: args.pd_decode_rpyc_port = pd_decode_rpyc_port + + set_unique_server_name(args) + + # 确保单机上多实列不冲突 + if args.zmq_mode == "ipc:///tmp/": + zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_" + args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功 + args.zmq_mode = zmq_mode + logger.info(f"zmq mode head: {args.zmq_mode}") + args.router_port = router_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port @@ -462,6 +496,13 @@ def normal_or_p_d_start(args): ], ) + return process_manager + + +def normal_or_p_d_start(args: StartArgs): + + process_manager = _launch_subprocesses(args) + # 启动 Hypercorn command = [ "hypercorn", @@ -497,7 +538,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -514,10 +555,7 @@ def pd_master_start(args): logger.info(f"all start args:{args}") can_use_ports = alloc_can_use_network_port( - num=1, - used_ports=[ - args.port, - ], + num=1, used_nccl_ports=[args.nccl_port, args.port], instance_id=args.lightllm_instance_id ) metric_port = can_use_ports[0] diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py index efe24c53e3..e78d56fc4f 100644 --- a/lightllm/server/audioserver/manager.py +++ b/lightllm/server/audioserver/manager.py @@ -11,11 +11,11 @@ from typing import List asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from lightllm.utils.log_utils import init_logger +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.multimodal_params import AudioItem from .model_infer import start_model_process, AudioModelRpcClient -from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name @@ -77,7 +77,7 @@ async def wait_to_model_ready(self): await asyncio.gather(*init_model_ret) return - def get_need_infer_audios(self, group_req_indexes: GroupReqIndexes) -> List[AudioItem]: + def get_need_infer_audios(self, group_req_indexes: GenerateReqIndex) -> List[AudioItem]: shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) is_aborted = shm_req.is_aborted disable_prompt_cache = shm_req.sample_params.disable_prompt_cache @@ -102,7 +102,7 @@ def get_need_infer_audios(self, group_req_indexes: GroupReqIndexes) -> List[Audi return audios_need_infer - async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + async def handle_group_indexes(self, group_req_indexes: GenerateReqIndex): audios_need_infer = self.get_need_infer_audios(group_req_indexes) if len(audios_need_infer) == 0: @@ -153,13 +153,16 @@ async def infer_audios(self, dp_index: int, audios, events): async def loop_for_netio_req(self): try: while True: - recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) + if isinstance(recv_req, GenerateReqIndex): logger.info( f"audio recv req id {recv_req.group_req_id} " f"audio count {len(recv_req.multimodal_params.audios)}" ) asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req)) + elif isinstance(recv_req, BaseReq): + # RL 等控制类 BaseReq 透传给下一模块,最终由 router 处理 + self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL) else: assert False, f"Error Req Inf {recv_req}" except Exception as e: diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py index c9b806c47d..10386b70e6 100644 --- a/lightllm/server/core/objs/io_objs/__init__.py +++ b/lightllm/server/core/objs/io_objs/__init__.py @@ -1 +1 @@ -from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd +from .group_req import AbortedReqCmd, StopStrMatchedReqCmd diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd2562..d644c0c316 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -1,33 +1,10 @@ from dataclasses import dataclass from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.core.objs.sampling_params import SamplingParams from typing import List from ..req import Req -@dataclass -class GroupReqIndexes: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_indexes: List[int] - time_mark: float - - -@dataclass -class GroupReqObjs: - group_req_id: int - multimodal_params: MultimodalParams - shm_req_objs: List[Req] - time_mark: float - - def to_group_req_index(self): - return GroupReqIndexes( - group_req_id=self.group_req_id, - multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], - time_mark=self.time_mark, - ) - - @dataclass class AbortedReqCmd: req_id: int diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index ea99dae5f6..8019c9a1a1 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ b/lightllm/server/core/objs/out_token_circlequeue.py @@ -4,6 +4,9 @@ LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1280)) LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8)) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class QueueItem(ctypes.Structure): @@ -24,9 +27,19 @@ def __init__(self): def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int): str_bytes = token_str.encode("utf-8") - assert ( - len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES - ), f"Token string {len(str_bytes)} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes." + max_data_len = max(LIGHTLLM_TOKEN_MAX_BYTES - 1, 0) + if len(str_bytes) > max_data_len: + logger.error( + "Token string exceeds max bytes: bytes=%d limit=%d src_index=%d count_output_tokens=%d preview=%s", + len(str_bytes), + max_data_len, + src_index, + count_output_tokens, + token_str, + ) + str_bytes = str_bytes[:max_data_len] + # Ensure truncation never leaves an incomplete UTF-8 sequence. + str_bytes = str_bytes.decode("utf-8", errors="ignore").encode("utf-8") ctypes.memmove(self.data, str_bytes, len(str_bytes)) self.data_len = len(str_bytes) self.src_index = src_index diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index bf7051abdc..4489ccd708 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -54,6 +54,8 @@ def __init__( # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. allowed_token_ids: Optional[List[int]] = None, + # if provided, the invalid token ids will be ignored during generation + invalid_token_ids: Optional[List[int]] = None, # p d mode used params group_request_id: Optional[int] = None, # move kv to deocde node, only used in pd mode @@ -89,6 +91,7 @@ def __init__( self.guided_grammar = guided_grammar self.guided_json = guided_json self.allowed_token_ids = allowed_token_ids + self.invalid_token_ids = invalid_token_ids self.group_request_id = group_request_id self.move_kv_to_decode_node = move_kv_to_decode_node self.suggested_dp_index = suggested_dp_index @@ -111,13 +114,18 @@ def __init__( def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() - cls._do_sample = generation_cfg.get("do_sample", False) - cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) - cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) - cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) - cls._temperature = generation_cfg.get("temperature", 1.0) - cls._top_p = generation_cfg.get("top_p", 1.0) - cls._top_k = generation_cfg.get("top_k", -1) + + def _cfg(key, default): + v = generation_cfg.get(key) + return v if v is not None else default + + cls._do_sample = _cfg("do_sample", False) + cls._presence_penalty = _cfg("presence_penalty", 0.0) + cls._frequency_penalty = _cfg("frequency_penalty", 0.0) + cls._repetition_penalty = _cfg("repetition_penalty", 1.0) + cls._temperature = _cfg("temperature", 1.0) + cls._top_p = _cfg("top_p", 1.0) + cls._top_k = _cfg("top_k", -1) cls._stop_sequences = generation_cfg.get("stop", None) except: pass @@ -269,6 +277,7 @@ def to_dict(self): ret["guided_grammar"] = self.guided_grammar ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids + ret["invalid_token_ids"] = self.invalid_token_ids ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node ret["seed"] = self.seed return ret diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 7f2b697091..d954870393 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -1,6 +1,7 @@ import os import math import ctypes +import base64 import numpy as np import time from .sampling_params import SamplingParams @@ -14,6 +15,7 @@ from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union from lightllm.utils.log_utils import init_logger +from lightllm.utils.shm_utils import create_or_link_shm logger = init_logger(__name__) @@ -25,19 +27,20 @@ class FinishStatus(ctypes.Structure): NO_FINISH = 0 FINISHED_STOP = 1 FINISHED_LENGTH = 2 + FINISHED_ABORTED = 3 def __init__(self, init_state=NO_FINISH): self.status = init_state def set_status(self, new_status): - assert 0 <= new_status <= 2 + assert 0 <= new_status <= 3 self.status = new_status def get_status(self): return self.status def is_finished(self): - return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH + return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED def is_stopped(self): return self.status == self.FINISHED_STOP @@ -50,6 +53,8 @@ def get_finish_reason(self): return "stop" elif self.status == self.FINISHED_LENGTH: return "length" + elif self.status == self.FINISHED_ABORTED: + return "abort" return None @@ -125,6 +130,8 @@ class Req(ctypes.Structure): ("token_hash_page_len_list", TokenPageLenList), # 用于保存查找匹配到的可以被复用的cpu cache 页面信息。 ("cpu_cache_match_page_indexes", CpuCachePageList), + # Number of tokens in routing data SHM, written by model worker, read by HTTP server. + ("shm_routing_num_tokens", ctypes.c_int), ] def get_str(self): @@ -182,6 +189,7 @@ def init( self._mtp_step = get_env_start_args().mtp_step self.stop_str_matched = False self.stop_str_matched_token_index = -1 + self.shm_routing_num_tokens = 0 self.post_init() @@ -277,6 +285,69 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8): + """Create routing SHM at actual size (on-demand, not pre-allocated). + + Uses smart mode: links if same-sized SHM exists, otherwise creates new. + """ + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_tokens, num_moe_layers, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.create_shm() + self.shm_routing_num_tokens = num_tokens + return + + def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8): + """Link to routing SHM from the reader side (HTTP server).""" + if num_moe_layers == 0: + return + num_tokens = self.shm_routing_num_tokens + if num_tokens <= 0: + return + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_tokens, num_moe_layers, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype) + self.shm_routing_data.link_shm() + return + + def get_routing_data(self): + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + return None + return self.shm_routing_data.arr + + def close_routing_data_shm_array(self): + """Close and unlink routing SHM (on-demand, no longer pooled).""" + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.close_shm() + self.shm_routing_data = None + self.shm_routing_num_tokens = 0 + return + + def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1): + if num_moe_layers == 0 or topk == 0: + return None + if self.shm_routing_num_tokens <= 0: + return None + try: + from lightllm.common.basemodel.routing_manager import routing_dtype_id_to_np + + np_dtype = routing_dtype_id_to_np(dtype_id) + if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None: + self.link_routing_data_shm_array(num_moe_layers, topk, np_dtype=np_dtype) + routing_data = self.get_routing_data() + if routing_data is None: + return None + return { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + except Exception as e: + logger.warning(f"Failed to read routing data for req {self.request_id}: {e}") + return None + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() @@ -297,9 +368,8 @@ def can_release(self): ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: - return True - + # if self.is_aborted and can_released_mark and ref_count_ok: + # return True ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty(): diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 707f8f31b5..28fc57e034 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -17,6 +17,7 @@ REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048)) GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048)) JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048)) +INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10)) class StopSequence(ctypes.Structure): @@ -205,6 +206,25 @@ def to_list(self): return list(self.ids[: self.size]) +class InvalidTokenIds(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH), + ("size", ctypes.c_int), + ] + + def initialize(self, ids: List[int]): + self.size = len(ids) + assert ( + self.size <= INVALID_TOKEN_IDS_MAX_LENGTH + ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}." + self.ids[: self.size] = ids[:] + return + + def to_list(self): + return list(self.ids[: self.size]) + + class ExponentialDecayLengthPenalty(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -293,6 +313,8 @@ class SamplingParams(ctypes.Structure): ("ignore_eos", ctypes.c_bool), # the max number of image patches to be used in the internvl model, for the test ("image_max_patch_num", ctypes.c_int), + ("min_pixels", ctypes.c_int), + ("max_pixels", ctypes.c_int), ("max_new_tokens", ctypes.c_int), ("min_new_tokens", ctypes.c_int), # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty @@ -304,6 +326,8 @@ class SamplingParams(ctypes.Structure): # processor which only retains scores for the given token ids. Defaults to None. # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server. ("allowed_token_ids", AllowedTokenIds), + # if provided, the invalid token ids will be ignored during generation + ("invalid_token_ids", InvalidTokenIds), ("stop_sequences", StopSequenceGroups), ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), ("group_request_id", ctypes.c_int64), # p d mode used params @@ -334,15 +358,31 @@ class SamplingParams(ctypes.Structure): def init(self, tokenizer, **kwargs): super().__init__() + # 移除kwargs中为null的参数,避免覆盖默认值 + kwargs = {k: v for k, v in kwargs.items() if v is not None} + self.best_of = kwargs.get("best_of", 1) self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) + do_sample = kwargs.get("do_sample", SamplingParams._do_sample) + self.do_sample = False if do_sample is None else do_sample + + presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) + self.presence_penalty = 0.0 if presence_penalty is None else presence_penalty + + frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) + self.frequency_penalty = 0.0 if frequency_penalty is None else frequency_penalty + + repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) + self.repetition_penalty = 1.0 if repetition_penalty is None else repetition_penalty + + temperature = kwargs.get("temperature", SamplingParams._temperature) + self.temperature = 1.0 if temperature is None else temperature + + top_p = kwargs.get("top_p", SamplingParams._top_p) + self.top_p = 1.0 if top_p is None else top_p + + top_k = kwargs.get("top_k", SamplingParams._top_k) + self.top_k = -1 if top_k is None else top_k self.ignore_eos = kwargs.get("ignore_eos", False) self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) self.max_new_tokens = kwargs.get("max_new_tokens", 65535) @@ -394,6 +434,11 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids = AllowedTokenIds() self.allowed_token_ids.initialize(allowed_token_ids) + # Initialize invalid_token_ids + invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) + self.invalid_token_ids = InvalidTokenIds() + self.invalid_token_ids.initialize(list[int](invalid_token_ids)) + if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -410,15 +455,34 @@ def init(self, tokenizer, **kwargs): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() + # Some checkpoints store null sampling fields in generation_config.json. + # Keep robust numeric defaults instead of propagating None into ctypes fields. cls._do_sample = generation_cfg.get("do_sample", False) + if cls._do_sample is None: + cls._do_sample = False + cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) + if cls._presence_penalty is None: + cls._presence_penalty = 0.0 + cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) + if cls._frequency_penalty is None: + cls._frequency_penalty = 0.0 + cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) if cls._repetition_penalty is None: cls._repetition_penalty = 1.0 cls._temperature = generation_cfg.get("temperature", 1.0) + if cls._temperature is None: + cls._temperature = 1.0 + cls._top_p = generation_cfg.get("top_p", 1.0) + if cls._top_p is None: + cls._top_p = 1.0 + cls._top_k = generation_cfg.get("top_k", -1) + if cls._top_k is None: + cls._top_k = -1 except: pass @@ -485,6 +549,8 @@ def to_dict(self): "image_max_patch_num": self.image_max_patch_num, "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, + "min_pixels": self.min_pixels, + "max_pixels": self.max_pixels, "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(), "stop_sequences": self.stop_sequences.to_list(), "best_of": self.best_of, @@ -493,6 +559,7 @@ def to_dict(self): "guided_grammar": self.guided_grammar.to_str(), "guided_json": self.guided_json.to_str(), "allowed_token_ids": self.allowed_token_ids.to_list(), + "invalid_token_ids": self.invalid_token_ids.to_list(), "group_request_id": self.group_request_id, "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), "skip_special_tokens": self.skip_special_tokens, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index bcc4b3798a..d98b7e612d 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass @@ -9,16 +9,27 @@ class StartArgs: run_mode: str = field( default="normal", metadata={ - "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode", "visual_only"] + "choices": [ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "pd_master", + "config_server", + "visual_only", + ] }, ) + performance_mode: str = field(default=None, metadata={"choices": ["personal"]}) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) @@ -26,18 +37,34 @@ class StartArgs: afs_image_embed_dir: str = field(default=None) afs_embed_capacity: int = field(default=250000) pd_decode_rpyc_port: int = field(default=None) - select_p_d_node_strategy: str = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") + model_owner: Optional[str] = field(default=None) model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( default=None, - metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]}, + metadata={ + "choices": [ + "qwen25", + "llama3", + "mistral", + "deepseekv3", + "qwen", + "deepseekv31", + "deepseekv32", + "glm47", + "kimi_k2", + "qwen3_coder", + ] + }, ) reasoning_parser: Optional[str] = field( default=None, @@ -60,20 +87,21 @@ class StartArgs: }, ) chat_template: Optional[str] = field(default=None) - running_max_req_size: int = field(default=512) + running_max_req_size: int = field(default=256) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) node_rank: int = field(default=0) - max_req_total_len: int = field(default=2048 + 1024) + max_req_total_len: int = field(default=16384) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=None) + lightllm_instance_id: int = field(default=0) use_config_server_to_init_nccl: bool = field(default=False) trust_remote_code: bool = field(default=False) detail_log: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) - router_token_ratio: float = field(default=0.0) + router_token_ratio: float = field(default=None) router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) @@ -81,7 +109,7 @@ class StartArgs: disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]}) + output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) enable_multimodal: bool = field(default=False) disable_vision: Optional[bool] = field(default=None) @@ -103,12 +131,12 @@ class StartArgs: health_monitor: bool = field(default=False) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") - grouping_key: List[str] = field(default_factory=list) + grouping_key: List[str] = field(default_factory=lambda: []) push_interval: int = field(default=10) visual_node_id: int = field(default=None) visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) - visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) + visual_gpu_ids: List[int] = field(default=None) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default=None) @@ -126,18 +154,18 @@ class StartArgs: graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) - quant_type: Optional[str] = field(default=None) + quant_type: Optional[str] = field(default="none") quant_cfg: Optional[str] = field(default=None) - vit_quant_type: Optional[str] = field(default=None) + vit_quant_type: Optional[str] = field(default="none") vit_quant_cfg: Optional[str] = field(default=None) llm_prefill_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) llm_decode_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) vit_att_backend: List[str] = field( - default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} ) llm_kv_type: str = field( default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} @@ -158,8 +186,6 @@ class StartArgs: "eagle_with_att", "vanilla_no_att", "eagle_no_att", - "qwen3next_vanilla", - "qwen3next_eagle", None, ] }, @@ -172,7 +198,7 @@ class StartArgs: pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) - cpu_cache_token_page_size: int = field(default=64) + cpu_cache_token_page_size: int = field(default=256) enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) disk_cache_dir: Optional[str] = field(default=None) @@ -187,6 +213,27 @@ class StartArgs: metric_port: int = field(default=None) multinode_httpmanager_port: int = field(default=12345) multi_level_kv_cache_port: int = field(default=None) + # multi_modal + enable_multimodal_audio: bool = field(default=False) + + disable_shm_warning: bool = field(default=False) + dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]}) + enable_custom_allgather: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) + enable_mps: bool = field(default=False) + multinode_router_gloo_port: int = field(default=20001) + schedule_time_interval: float = field(default=0.03) + use_dynamic_prompt_cache: bool = field(default=False) + disable_custom_allreduce: bool = field(default=False) + enable_torch_memory_saver: bool = field(default=False) + enable_weight_cpu_backup: bool = field(default=False) + hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]}) + enable_torch_fallback: bool = field(default=False) + enable_triton_fallback: bool = field(default=False) + + enable_return_routed_experts: bool = field(default=False) + + weight_version: str = "default" # hybrid attention model (Qwen3Next) linear_att_hash_page_size: int = field(default=512) diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9aa3a8effc..c77379986c 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool: return False def need_detoken(self): - if ( - (not self.req.is_aborted) - and (not self.req.stop_str_matched) - and len(self.output_ids) < self.req.candetoken_out_len - ): + if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len: return True return False @@ -83,8 +79,6 @@ def get_decode_tokens(self): return prefix_tokens, read_tokens def can_set_release_mark(self): - if self.req.is_aborted: - return True if self.req.stop_str_matched: return True if ( diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..17a47dfde6 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -6,7 +6,6 @@ import zmq import inspect from lightllm.server.core.objs import ShmReqManager, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.utils.graceful_utils import graceful_registry from typing import Union, Dict, List from .decode import decode_token @@ -17,6 +16,14 @@ import time from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import ( + BaseReq, + GenerateResp, + FlushCacheResp, + ReleaseMemoryResp, + ResumeMemoryResp, + GeneralModelToHttpRpcRsp, +) logger = init_logger(__name__) @@ -31,9 +38,9 @@ def __init__( self.zmq_recv_socket = context.socket(zmq.PULL) self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") - self.pub_to_httpserver = context.socket(zmq.PUB) - self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") + self.send_to_httpserver = context.socket(zmq.PUSH) + self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") + logger.info(f"send_to_httpserver sendhwm {self.send_to_httpserver.getsockopt(zmq.SNDHWM)}") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} @@ -46,7 +53,7 @@ def _init_get_token_id_to_token_str(self): self.token_id_to_token = {token_id: token for token, token_id in self.tokenizer.get_vocab().items()} return - def _add_new_group_req_index(self, recv_obj: GroupReqIndexes): + def _add_new_group_req_index(self, recv_obj: BaseReq): for req_index in recv_obj.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) req.link_prompt_ids_shm_array() @@ -74,8 +81,10 @@ def handle_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_obj, GeneralModelToHttpRpcRsp): + self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) + continue self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 @@ -146,7 +155,7 @@ def gen_token_out(self): # 通知 httpserver 进程 if exist_decode: - self.pub_to_httpserver.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_httpserver.send_pyobj(GenerateResp(), protocol=pickle.HIGHEST_PROTOCOL) self.remove_finished_reqs() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 892e202e2d..30e2db0e53 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import inspect from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -25,16 +26,41 @@ from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE -from lightllm.server.core.objs.io_objs import GroupReqObjs from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.common.basemodel.routing_manager import get_routing_config_shm from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.server.io_struct import ( + AbortReq, + BaseReq, + FlushCacheReq, + FlushCacheResp, + GenerateReq, + GenerateResp, + GenerateReqMeta, + GenerateReqIndex, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, + InitWeightsUpdateGroupReq, + InitWeightsUpdateGroupRsp, + DestroyWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupRsp, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromDistributedRsp, + UpdateWeightsFromTensorReq, + UpdateWeightsFromTensorRsp, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp, +) from lightllm.utils.statics_utils import MovingAverage from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.torch_memory_saver_utils import MemoryTag from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -74,7 +100,7 @@ def __init__( self.multinode_req_manager = context.socket(zmq.PULL) self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}") logger.info( - f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" + f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}" ) self.enable_multimodal = args.enable_multimodal @@ -98,9 +124,8 @@ def __init__( self.shm_req_manager = ShmReqManager() # recv from detokenization - self.zmq_recv_socket = context.socket(zmq.SUB) + self.zmq_recv_socket = context.socket(zmq.PULL) self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") - self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -122,6 +147,18 @@ def __init__( # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark") self.latest_success_infer_time_mark.set_value(int(time.time())) + + # Cache routing config for MoE expert routing data extraction + self._routing_shm = get_routing_config_shm() if args.enable_return_routed_experts else None + + self.is_pause = False + self.is_pause_cond = asyncio.Condition() + + # 交互式请求 event + self.flush_cache_event: Optional[asyncio.Event] = None + self.release_memory_event: Optional[asyncio.Event] = None + self.resume_memory_event: Optional[asyncio.Event] = None + self.async_events_per_func: Dict[str, asyncio.Event] = {} return def _log_stage_timing(self, group_request_id: int, start_time: float, stage: str, **kwargs): @@ -249,18 +286,32 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def loop_for_request(self): assert self.args.node_rank > 0 while True: - ( - prompt, - sampling_params, - multimodal_params, - ) = await self.multinode_req_manager.recv_pyobj() - results_generator = self.generate(prompt, sampling_params, multimodal_params, None) + req_obj = await self.multinode_req_manager.recv_pyobj() + if req_obj is None: + continue + if isinstance(req_obj, GenerateReqMeta): + self.process_generate_request(req_obj) + elif isinstance(req_obj, AbortReq): + self.process_abort_request(req_obj) + else: + assert False, f"Unknown request type: {type(req_obj)}" + return - async def generate_wrapper(results_generator): - async for _, _, _, _ in results_generator: - pass + def process_generate_request(self, req_meta: GenerateReqMeta): + prompt = req_meta.prompt + sampling_params = req_meta.sampling_params + multimodal_params = req_meta.multimodal_params + results_generator = self.generate(prompt, sampling_params, multimodal_params, None) - asyncio.create_task(generate_wrapper(results_generator)) + async def generate_wrapper(results_generator): + async for _, _, _, _ in results_generator: + pass + + asyncio.create_task(generate_wrapper(results_generator)) + return + + def process_abort_request(self, request: AbortReq): + asyncio.create_task(self.abort_request(request)) return def alloc_req_id(self, sampling_params, is_health_req: bool = False): @@ -313,10 +364,6 @@ async def generate( ) try: - original_multimodal_params = None - if self.is_multinode_tp_master: - original_multimodal_params = copy.deepcopy(multimodal_params) - if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) self._log_stage_timing( @@ -325,8 +372,20 @@ async def generate( "verify_and_preload_done", ) + # Debug logging for multimodal requests + if multimodal_params and multimodal_params.images: + logger.debug( + f"[MULTIMODAL_DEBUG] req_id={group_request_id}, " + f"num_images={len(multimodal_params.images)}, " + f"max_new_tokens={sampling_params.max_new_tokens}" + ) + # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) + + async with self.is_pause_cond: + await self.is_pause_cond.wait_for(lambda: not self.is_pause) + # encode prompt_ids = await self._encode(prompt, multimodal_params, sampling_params) self._log_stage_timing( @@ -400,18 +459,17 @@ async def generate( "shm_req_init_done", ) - logger.debug( - f"alloc shm_req for req_id {group_request_id}, " - f"shm_req num: {sampling_params.n} details (req_id, index_in_shm_mem): " - f"{[(req_obj.request_id, req_obj.index_in_shm_mem) for req_obj in req_objs]}" + req_status = ReqStatus( + group_request_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, + multimodal_params=multimodal_params, + req_objs=req_objs, + start_time=start_time, ) - - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) self.req_id_to_out_inf[group_request_id] = req_status - await self.transfer_to_next_module_or_node( - prompt, sampling_params, original_multimodal_params, req_status.group_req_objs - ) + await self.transfer_to_next_module_or_node(req_status.group_req_objs) self._log_stage_timing( group_request_id, start_time, @@ -527,7 +585,21 @@ async def _encode( # 这里的校验对多模态不是很充分, to do if all(isinstance(e, int) for e in prompt): - if not self.enable_multimodal and not self.pd_mode.is_D(): + if self.enable_multimodal: + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" + if multimodal_params.audios: + assert self.args.enable_multimodal_audio, "audio multimodal not enabled" + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + prompt_ids = self.tokenizer.encode( + prompt, + multimodal_params, + add_special_tokens=sampling_params.add_special_tokens, + already_tokenized=True, + ) + return prompt_ids + elif not self.enable_multimodal and not self.pd_mode.is_D(): if all(e < self.vocab_size for e in prompt): return prompt else: @@ -581,45 +653,50 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: async def transfer_to_next_module_or_node( self, - prompt: str, - sampling_params: SamplingParams, - original_multimodal_params: MultimodalParams, - group_req_objs: Optional[GroupReqObjs] = None, + req_obj: Optional["BaseReq"] = None, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. + req_to_next_node = req_obj.get_req_to_next_node() + self.transfer_to_next_node(req_to_next_node) + req_to_next_module = req_obj.get_req_to_next_module() + await self.transfer_to_next_module(req_to_next_module) + return + + def transfer_to_next_node( + self, + req_to_next_node: Optional["BaseReq"] = None, + ): if self.is_multinode_tp_master: for sender in self.multinode_req_manager: sender.send_pyobj( - (prompt, sampling_params, original_multimodal_params), + req_to_next_node, protocol=pickle.HIGHEST_PROTOCOL, ) - - await self.transfer_to_next_module(group_req_objs) return async def transfer_to_next_module( self, - group_req_objs: Optional[GroupReqObjs] = None, + req_to_next_module: Optional["GenerateReqIndex"] = None, ): if self.pd_mode.is_P_or_NORMAL(): if not self.args.disable_vision: - self.send_to_visual.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_visual.send_pyobj(req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL) return if not self.args.disable_audio: - self.send_to_audio.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_audio.send_pyobj(req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL) return if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -627,7 +704,7 @@ async def transfer_to_next_module( if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -761,12 +838,24 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_id} not exist") return False - group_req_objs: GroupReqObjs = req_status.group_req_objs + group_req_objs: GenerateReq = req_status.group_req_objs for req in group_req_objs.shm_req_objs: req.is_aborted = True logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + async def abort_request(self, request: AbortReq): + request_id = request.request_id + abort_all = request.abort_all + if self.is_multinode_tp_master: + self.transfer_to_next_node(req_to_next_node=request) + if request_id is not None and not abort_all: + await self.abort(request_id) + if abort_all: + for group_req_id in list(self.req_id_to_out_inf.keys()): + await self.abort(group_req_id) + pass + async def recycle_resource_loop(self): pre_time_mark = time.time() @@ -789,6 +878,11 @@ async def recycle_resource_loop(self): self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) _is_aborted = False for req in req_status.group_req_objs.shm_req_objs: + if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") _is_aborted = _is_aborted or req.is_aborted logger.debug(f"httpserver release req_id {req.request_id}, index {req.index_in_shm_mem}") await self.shm_req_manager.async_put_back_req_obj(req) @@ -829,65 +923,16 @@ async def handle_loop(self): while True: try: - await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) + recv_obj = await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) except asyncio.TimeoutError: - pass + recv_obj = None try: - for group_req_id_ in list(self.req_id_to_out_inf.keys()): - req_status = self.req_id_to_out_inf.get(group_req_id_, None) - if req_status is None: - continue - - token_list = [] - for req in req_status.group_req_objs.shm_req_objs: - req_id = req.request_id - read_token_count = 1 - if req.out_tokens_queue.is_full(): - read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE - - for _ in range(read_token_count): - if not req.out_tokens_queue.is_empty(): - - text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() - req.cumlogprob += float(req.shm_logprobs.arr[src_index]) - metadata = { - "id": int(req.shm_prompt_ids.arr[src_index]), - "logprob": float(req.shm_logprobs.arr[src_index]), - "cumlogprob": float(req.cumlogprob) / count_output_tokens, - "special": special, - "count_output_tokens": count_output_tokens, - "prompt_cache_len": req.prompt_cache_len, - "cpu_prompt_cache_len": req.cpu_prompt_cache_len, - "disk_prompt_cache_len": req.disk_prompt_cache_len, - "mtp_accepted_token_num": req.mtp_accepted_token_num, - } - if self.args.return_all_prompt_logprobs: - metadata.update(req.get_all_prompt_metadata()) - if self.args.use_reward_model: - metadata["score"] = float(req.reward_score) - - req.out_tokens_queue.pop_no_ret() - - finished_token_index = ( - req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index - ) - - if finished_token_index != src_index: - token_list.append((req_id, text, metadata, FinishStatus())) - else: - if req.stop_str_matched: - finish_status = FinishStatus(FinishStatus.FINISHED_STOP) - else: - finish_status = FinishStatus(req.finish_status.status) - - token_list.append((req_id, text, metadata, finish_status)) - else: - break + if recv_obj is None or isinstance(recv_obj, GenerateResp): + await self._handle_recv_generate_request(recv_obj) + elif isinstance(recv_obj, GeneralModelToHttpRpcRsp): + await self._handle_recv_general_model_to_http_request(recv_obj) - async with req_status.lock: - req_status.out_token_info_list.extend(token_list) - req_status.event.set() except BaseException as e: logger.exception(str(e)) raise e @@ -895,13 +940,189 @@ async def handle_loop(self): self.recycle_event.set() return + async def _handle_recv_generate_request(self, recv_obj: GenerateReqMeta): + for group_req_id_ in list(self.req_id_to_out_inf.keys()): + req_status = self.req_id_to_out_inf.get(group_req_id_, None) + if req_status is None: + continue + + token_list = [] + for req in req_status.group_req_objs.shm_req_objs: + req_id = req.request_id + read_token_count = 1 + if req.out_tokens_queue.is_full(): + read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE + + for _ in range(read_token_count): + if not req.out_tokens_queue.is_empty(): + + text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() + req.cumlogprob += float(req.shm_logprobs.arr[src_index]) + metadata = { + "id": int(req.shm_prompt_ids.arr[src_index]), + "logprob": float(req.shm_logprobs.arr[src_index]), + "cumlogprob": float(req.cumlogprob) / count_output_tokens, + "special": special, + "count_output_tokens": count_output_tokens, + "prompt_cache_len": req.prompt_cache_len, + "cpu_prompt_cache_len": req.cpu_prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, + } + if self.args.return_all_prompt_logprobs: + metadata.update(req.get_all_prompt_metadata()) + if self.args.use_reward_model: + metadata["score"] = float(req.reward_score) + + req.out_tokens_queue.pop_no_ret() + + finished_token_index = ( + req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index + ) + + if finished_token_index != src_index: + token_list.append((req_id, text, metadata, FinishStatus())) + else: + if req.stop_str_matched: + finish_status = FinishStatus(FinishStatus.FINISHED_STOP) + else: + finish_status = FinishStatus(req.finish_status.status) + + if self._routing_shm is not None: + _num_moe = int(self._routing_shm.arr[0]) + _topk = int(self._routing_shm.arr[1]) + _dtype_id = int(self._routing_shm.arr[2]) + if _num_moe > 0: + routing_meta = req.get_routing_metadata(_num_moe, _topk, dtype_id=_dtype_id) + if routing_meta is not None: + metadata["routed_experts"] = routing_meta + + token_list.append((req_id, text, metadata, finish_status)) + else: + break + + async with req_status.lock: + req_status.out_token_info_list.extend(token_list) + req_status.event.set() + + async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp): + assert recv_obj.func_name is not None + event = await self.get_event_for_func(recv_obj.func_name) + event.result = recv_obj + event.set() + return + + async def pause_generation(self): + # 因为请求是从master node转发到slave node的 + # 所以只要master暂停了,slave自然暂停。 + if self.is_pause: + return + async with self.is_pause_cond: + self.is_pause = True + while True: + await self.abort_request(AbortReq(request_id=None, abort_all=True)) + running_req_num = len(list(self.req_id_to_out_inf.keys())) + if running_req_num == 0: + break + await asyncio.sleep(1.0) + + async def continue_generation(self): + async with self.is_pause_cond: + self.is_pause = False + self.is_pause_cond.notify_all() + + async def get_event_for_func(self, func_name: str) -> asyncio.Event: + if func_name not in self.async_events_per_func: + self.async_events_per_func[func_name] = asyncio.Event() + return self.async_events_per_func[func_name] + + async def http_to_model_special_request( + self, request: GeneralHttpToModelRpcReq, timeout: int = 300 + ) -> GeneralModelToHttpRpcRsp: + event = await self.get_event_for_func(request.func_name) + await self.transfer_to_next_module(request) + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + ret = event.result + + except asyncio.TimeoutError: + ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", func_name=request.func_name) + except Exception as e: + ret = GeneralModelToHttpRpcRsp( + success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name + ) + return ret + + async def flush_cache(self, request: FlushCacheReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=request) + ) + + async def release_memory_occupation(self, request: ReleaseMemoryReq): + assert len(self.req_id_to_out_inf) == 0, "there are still requests running, cannot release memory occupation" + # 暂停接受请求,除非resume + await self.pause_generation() + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="release_memory_occupation", func_args=request.tags) + ) + + async def resume_memory_occupation(self, request: ResumeMemoryReq): + ret = await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="resume_memory_occupation", func_args=request.tags) + ) + if ret.success: + await self.continue_generation() + return ret + + async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="init_weights_update_group", func_args=request) + ) + + async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="destroy_weights_update_group", func_args=request) + ) + + async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache(FlushCacheReq()) + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request) + ) + + async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> Tuple[bool, str]: + if request.abort_all_requests: + await self.abort_request(AbortReq(abort_all=True)) + + if request.flush_cache: + await self.flush_cache(FlushCacheReq()) + + return await self.http_to_model_special_request( + GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request) + ) + class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: + def __init__( + self, + group_request_id: int, + prompt: str, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + req_objs: List[Req], + start_time, + ) -> None: self.lock = asyncio.Lock() self.event = asyncio.Event() - self.group_req_objs = GroupReqObjs( + self.group_req_objs = GenerateReq( group_req_id=group_request_id, + prompt=prompt, + sampling_params=sampling_params, multimodal_params=multimodal_params, shm_req_objs=req_objs, time_mark=start_time, diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py new file mode 100644 index 0000000000..e04e8871ce --- /dev/null +++ b/lightllm/server/io_struct.py @@ -0,0 +1,195 @@ +from abc import ABC +from dataclasses import dataclass +from lightllm.server.core.objs.req import Req +from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from typing import List, Optional, Any, Union +from lightllm.utils.torch_memory_saver_utils import MemoryTag + + +@dataclass +class BaseReq(ABC): + def get_req_to_next_node(self): + return self + + def get_req_to_next_module(self): + return self + + +@dataclass +class BaseRsp(ABC): + success: bool + msg: Optional[str] + + +# for next node +@dataclass +class GenerateReqMeta(BaseReq): + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + + +# for next module +@dataclass +class GenerateReqIndex(BaseReq): + group_req_id: int + multimodal_params: MultimodalParams + shm_req_indexes: List[int] + time_mark: float + + +@dataclass +class GenerateReq(BaseReq): + group_req_id: int + prompt: str + sampling_params: SamplingParams + multimodal_params: MultimodalParams + shm_req_objs: List[Req] + time_mark: float + + def get_req_to_next_module(self): + # 已经完成跨节点转发,可以释放图片原始资源 + self.multimodal_params.free() + return GenerateReqIndex( + group_req_id=self.group_req_id, + multimodal_params=self.multimodal_params, + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + time_mark=self.time_mark, + ) + + def get_req_to_next_node(self): + return GenerateReqMeta( + prompt=self.prompt, + sampling_params=self.sampling_params, + multimodal_params=self.multimodal_params, + ) + + +@dataclass +class GenerateResp(BaseReq): + pass + + +@dataclass +class FlushCacheReq(BaseReq): + pass + + +@dataclass +class FlushCacheResp(BaseReq): + success: bool + + +@dataclass +class AbortReq(BaseReq): + # 外部调用传入,等同内部的 group_req_id + request_id: int = None + abort_all: bool = False + + +@dataclass +class ReleaseMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ReleaseMemoryResp(BaseReq): + success: bool + + +@dataclass +class ResumeMemoryReq(BaseReq): + tags: Optional[List[MemoryTag]] = None + + +@dataclass +class ResumeMemoryResp(BaseReq): + success: bool + + +@dataclass +class GeneralHttpToModelRpcReq(BaseReq): + func_name: str + func_args: Optional[Any] = None + + +@dataclass +class GeneralModelToHttpRpcRsp(BaseRsp): + func_name: str + func_rsp: Optional[Any] = None + + +@dataclass +class InitWeightsUpdateGroupReq(BaseReq): + # The master address + master_address: str + # The master port + master_port: int + # The rank offset + rank_offset: int + # The world size + world_size: int + # The group name + group_name: str = "weight_update_group" + # The backend + backend: str = "nccl" + + +@dataclass +class InitWeightsUpdateGroupRsp(BaseRsp): + pass + + +@dataclass +class DestroyWeightsUpdateGroupReq(BaseReq): + group_name: str = "weight_update_group" + + +@dataclass +class DestroyWeightsUpdateGroupRsp(BaseRsp): + pass + + +@dataclass +class UpdateWeightsFromDistributedReq(BaseReq): + names: List[str] + dtypes: List[str] + shapes: List[List[int]] + # The group name + group_name: str = "weight_update_group" + # Whether to flush the cache after updating weights + flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromDistributedRsp(BaseRsp): + pass + + +@dataclass +class UpdateWeightsFromTensorReq(BaseReq): + """Update model weights from tensor input. + + - Tensors are serialized for transmission + - Data is structured in JSON for easy transmission over HTTP + """ + + serialized_named_tensors: List[Union[str, bytes]] + # Optional format specification for loading + load_format: Optional[str] = None + # Whether to flush the cache after updating weights + flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromTensorRsp(BaseRsp): + pass diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py index 0a7dec0005..b771ac1a01 100644 --- a/lightllm/server/multi_level_kv_cache/manager.py +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -12,7 +12,7 @@ from queue import Queue from typing import List from lightllm.server.core.objs import ShmReqManager, Req, StartArgs -from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.server.io_struct import GenerateReqIndex, BaseReq from lightllm.utils.graceful_utils import graceful_registry from .cpu_cache_client import CpuKvCacheClient from lightllm.utils.log_utils import init_logger @@ -135,7 +135,7 @@ def _disk_cache_match(self, token_hash_list: List[int], all_pages: List[int]) -> self.cpu_cache_client.lock.release() return all_pages, len(new_page_indexes) - def _handle_group_req_multi_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float): + def _handle_group_req_multi_cache_match(self, group_req_indexes: GenerateReqIndex, start_time: float): """ match cpu cache and disk cache pages """ @@ -218,8 +218,12 @@ def recv_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - assert isinstance(recv_obj, GroupReqIndexes) + recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if not isinstance(recv_obj, GenerateReqIndex): + # RL 等控制类 BaseReq 透传给 router 处理,避免在此被静默丢弃 + if isinstance(recv_obj, BaseReq): + self.send_to_router.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL) + continue recv_objs.append(recv_obj) start_time = recv_obj.time_mark diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 6210628751..b05f34bd99 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -70,9 +70,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -155,9 +157,11 @@ def read(self): assert self._preload_data is not None ans = self._preload_data self._preload_data = None - self._data = None return ans + def free(self): + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid @@ -211,3 +215,10 @@ def to_origin_dict(self): ret["images"] = [i.to_origin_dict() for i in self.images] ret["audios"] = [a.to_origin_dict() for a in self.audios] return ret + + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + return diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py index f7c099c292..bc81d835b6 100644 --- a/lightllm/server/req_id_generator.py +++ b/lightllm/server/req_id_generator.py @@ -30,7 +30,8 @@ def __init__(self): self.current_id.arr[0] = 0 self.current_id.arr[1] = 0 self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock") - self._wait_all_workers_ready() + if self.args.httpserver_workers > 1: + self._wait_all_workers_ready() logger.info("ReqIDGenerator init finished") def _wait_all_workers_ready(self): diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py index 73c6dba54d..082f6bec08 100644 --- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py @@ -467,6 +467,10 @@ def clear_tree_nodes(self): self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) return + def flush_cache(self): + self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) + return + def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]: assert not node.is_big_page_node() iter_node = node diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 88b099459b..955e808600 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -425,6 +425,10 @@ def clear_tree_nodes(self): self.refed_tokens_num.arr[0] = 0 return + def flush_cache(self): + self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num) + return + def dec_node_ref_counter(self, node: TreeNode): if node is None: return diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index f5e0b8df9a..83972a3d34 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -5,6 +5,7 @@ import pickle import inspect import setproctitle +import rpyc asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq @@ -17,7 +18,6 @@ from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue from lightllm.server.core.objs.io_objs import ( - GroupReqIndexes, AbortedReqCmd, StopStrMatchedReqCmd, ) @@ -29,6 +29,18 @@ from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock +from lightllm.server.io_struct import ( + BaseReq, + GenerateReqIndex, + FlushCacheReq, + FlushCacheResp, + ReleaseMemoryReq, + ReleaseMemoryResp, + ResumeMemoryReq, + ResumeMemoryResp, + GeneralHttpToModelRpcReq, + GeneralModelToHttpRpcRsp, +) from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -356,8 +368,13 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]: ans = [] if self.running_batch is None: return ans - for req in self.running_batch.reqs: - if req.is_aborted and req._router_aborted is False: + aborted_req_mask = torch.tensor( + [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu" + ) + if self.is_multinode_tp: + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + for req, is_aborted in zip(self.running_batch.reqs, aborted_req_mask.numpy()): + if is_aborted and req._router_aborted is False: req._router_aborted = True ans.append(req) return ans @@ -406,7 +423,7 @@ def get_used_tokens(self, dp_index): else: return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - def _add_req(self, group_req_indexes: GroupReqIndexes): + def _add_req(self, group_req_indexes: BaseReq): req_group = [] for req_index in group_req_indexes.shm_req_indexes: req = self.shm_req_manager.get_req_obj_by_index(req_index) @@ -454,9 +471,22 @@ def _multinode_tp_generate_new_batch(self): dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) req_id_select_mark = [1 for _ in range(len(req_ids))] req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") + # TODO: 这里可以合成一个 allreudce,req_id_select_mark + aborted_req_mask dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) + aborted_req_mask = torch.tensor( + [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu" + ) + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) back_req_list = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + # 释放多节点abort 请求,如果select == 0, is_aborted 一定为False + if is_aborted and select == 1: + req = new_batch.pop_req(req_id) + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue if select == 0: req = new_batch.pop_req(req_id) back_req_list.append(req) @@ -472,23 +502,28 @@ def _multinode_tp_generate_new_batch(self): else: req_ids = [None for _ in range(req_num)] dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group) - all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + # all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list]) + id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list} req_id_select_mark = [] + aborted_req_mask = [] for req_id in req_ids: - req_id_select_mark.append(1 if req_id in all_req_id_set else 0) + req_id_select_mark.append(1 if req_id in id_to_req_obj else 0) + aborted_req_mask.append(id_to_req_obj[req_id].is_aborted if req_id in id_to_req_obj else False) req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu") dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group) - select_req_ids = [] - for req_id, select in zip(req_ids, req_id_select_mark.numpy()): - if select == 1: - select_req_ids.append(req_id) - + aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu") + dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group) select_reqs = [] - for req_id in select_req_ids: - for req in self.req_queue.waiting_req_list: - if req.request_id == req_id: - select_reqs.append(req) - + for req_id, select, is_aborted in zip( + req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy() + ): + if select == 1: + req = id_to_req_obj[req_id] + if is_aborted: + self.req_queue.free_aborted_req(req) + self.shm_req_manager.put_back_req_obj(req) + continue + select_reqs.append(req) for req in select_reqs: self.req_queue.waiting_req_list.remove(req) if select_reqs: @@ -509,13 +544,17 @@ async def _recv_new_reqs_and_schedule(self): self.recv_max_count = 64 try: + # 多机tp需要广播给其他node的请求 + special_reqs = [] # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_req, GenerateReqIndex): self._add_req(recv_req) + elif isinstance(recv_req, GeneralHttpToModelRpcReq): + special_reqs.append(recv_req) else: - assert False, f"Error Req Inf {recv_req}" + raise ValueError(f"Unknown request type: {type(recv_req)}") # 当队列中存在较多的请求时,将一次接受的数量上调 self.recv_max_count = min(int(self.recv_max_count * 1.3), 256) @@ -524,6 +563,8 @@ async def _recv_new_reqs_and_schedule(self): # 当队列已经开始清空的时候,将一次接受的数量下调 self.recv_max_count = 64 + await self._process_special_reqs(special_reqs) + if self.is_multinode_tp: self._multinode_tp_generate_new_batch() else: @@ -531,6 +572,50 @@ async def _recv_new_reqs_and_schedule(self): self._generate_new_batch() return + async def _process_special_reqs(self, special_reqs: List[BaseReq]): + if self.is_multinode_tp: + special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs) + for req in special_reqs: + assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq" + await self.forward_to_model(req) + + def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]): + req_num = len(reqs) + if self.node_rank == 0: + req_nums = [len(reqs)] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + else: + req_nums = [None] + dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group) + req_num = req_nums[0] + if req_num > 0: + reqs = [None for _ in range(req_num)] + dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group) + return reqs + + async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None: + forward_to_model_tasks = [] + for model_rpc_client in self.model_rpc_clients: + forward_to_model_tasks.append(model_rpc_client.forward_to_model(req)) + all_ret = await asyncio.gather(*forward_to_model_tasks) + succes = all(ret.success for ret in all_ret) + ret = all_ret[0] + ret.success = succes + if self.is_multinode_tp: + output_list = [None for _ in self.nnodes] if self.node_rank == 0 else None + dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group) + for res in output_list: + res: GeneralModelToHttpRpcRsp + if not res.success: + ret = res + break + + if self.node_rank == 0: + self.send_to_detokenization.send_pyobj(ret, protocol=pickle.HIGHEST_PROTOCOL) + def clean_up(self): return diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 3a27c082de..b5a68299ca 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -24,6 +24,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel import routing_manager as _routing_mgr logger = init_logger(__name__) @@ -122,6 +123,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs + def _extract_routing_data(self, req: "InferReq"): + if req.shm_req.shm_routing_num_tokens > 0: + return + mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] + mgr = _routing_mgr.g_routing_capture_manager + routing_data = mgr.extract_routing_data(mem_indexes) + req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype) + req.shm_req.shm_routing_data.arr[:] = routing_data + req.shm_req.shm_routing_data.detach_shm() + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -444,6 +455,9 @@ def __init__( if len(self.allowed_token_ids) == 0: self.allowed_token_ids = None + # if provided, invalid_token_ids are masked to -inf during sampling (see generic_post_process.sample) + self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list() + # p d mode use params if self.shm_param.move_kv_to_decode_node.exists: self.move_kv_to_decode_node = self.shm_param.move_kv_to_decode_node.to_dict() @@ -456,6 +470,11 @@ def __init__( logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids") self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size] + if len(self.invalid_token_ids) > 0: + if not all(e < vocab_size for e in self.invalid_token_ids): + logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids") + self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size] + # nixl decode node information if self.shm_param.nixl_params.data_len > 0: self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) @@ -929,6 +948,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): + if _routing_mgr.g_routing_capture_manager is not None: + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 048869f860..1357e13462 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,7 +4,7 @@ import time import threading import torch.distributed as dist -from typing import List, Tuple, Callable, Optional +from typing import List, Tuple, Callable, Optional, Union from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger @@ -19,7 +19,7 @@ from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify -from lightllm.utils.dist_utils import init_distributed_env +from lightllm.utils.dist_utils import init_distributed_env, init_custom_process_group from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd @@ -34,8 +34,11 @@ enable_radix_tree_timer_merge, get_radix_tree_merge_update_delta, ) +from lightllm.utils.serializer import LocalSerializedTensor, MultiprocessingSerializer +from lightllm.utils.patch_torch import monkey_patch_torch_reductions +from lightllm.utils.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata +from lightllm.distributed import dist_group_manager from lightllm.distributed.communication_op import ( - dist_group_manager, all_gather_into_tensor, all_reduce, broadcast, @@ -49,7 +52,16 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from lightllm.common.basemodel import routing_manager as _routing_mgr +from lightllm.utils.torch_memory_saver_utils import MemoryTag from .multi_level_kv_cache import MultiLevelKvCacheModule +from lightllm.server.io_struct import ( + FlushCacheReq, + InitWeightsUpdateGroupReq, + DestroyWeightsUpdateGroupReq, + UpdateWeightsFromDistributedReq, + UpdateWeightsFromTensorReq, +) class ModeBackend: @@ -122,6 +134,8 @@ def init_model(self, kvargs): ) dist_group_manager.create_groups(group_size=group_size) # set the default group + self._model_update_group = {} + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 @@ -315,14 +329,15 @@ def init_mtp_draft_model(self, main_kvargs: dict): os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" - if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: + if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att", "qwen3next_vanilla"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att", "qwen3next_eagle"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" for i in range(num_mtp_modules): + # Get MTP model config first to calculate mem_layer_start mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) model_type = mtp_model_cfg.get("model_type", "") mtp_model_kvargs = { @@ -337,7 +352,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "data_type": main_kvargs.get("data_type", "float16"), "graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16), "graph_max_len_in_batch": main_kvargs.get("graph_max_len_in_batch", 8196), - "disable_cudagraph": main_kvargs.get("disable_cudagraph", False), + "disable_cudagraph": True, # Disable CUDA graphs for MTP draft models "mem_fraction": main_kvargs["mem_fraction"], "batch_max_tokens": main_kvargs.get("batch_max_tokens", None), "quant_type": main_kvargs.get("quant_type", None), @@ -345,6 +360,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): "run_mode": "normal", "main_model": self.model, "mtp_previous_draft_models": self.draft_models.copy(), + "mtp_index": i, } # Select MTP model class based on model type @@ -367,6 +383,191 @@ def init_mtp_draft_model(self, main_kvargs: dict): self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return + def flush_cache(self, request: FlushCacheReq): + if self.radix_cache is not None: + self.radix_cache.flush_cache() + return True, "Succeeded to flush cache." + + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.req_manager.free_all() + self.model.mem_manager.free_all() + self.model.release_memory_occupation(tags) + self.flush_cache(request=None) + return True, "Succeeded to release memory occupation." + except Exception as e: + self.logger.error(f"release memory occupation failed: {str(e)}") + return False, f"release memory occupation failed: {str(e)}" + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.model.resume_memory_occupation(tags) + return True, "Succeeded to resume memory occupation." + except Exception as e: + self.logger.error(f"resume memory occupation failed: {str(e)}") + return False, f"resume memory occupation failed: {str(e)}" + + def init_weights_update_group(self, request: InitWeightsUpdateGroupReq): + assert torch.distributed.is_initialized(), "Default torch process group must be initialized" + + assert request.group_name != "", "Group name cannot be empty" + rank_offset = request.rank_offset + rank = rank_offset + self.rank_in_dp + world_size = request.world_size + group_name = request.group_name + self.logger.info( + f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, " + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, " + f" backend={request.backend}" + ) + + try: + if group_name in self._model_update_group: + raise ValueError(f"Process group with name {group_name} already exists.") + + self._model_update_group[group_name] = init_custom_process_group( + backend=request.backend, + init_method=f"tcp://{request.master_address}:{request.master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + return True, "Succeeded to initialize custom process group." + + except Exception as e: + message = f"Failed to initialize custom process group: {e}." + self.logger.error(message) + return False, message + + def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq): + try: + if request.group_name in self._model_update_group: + pg = self._model_update_group.pop(request.group_name) + torch.distributed.destroy_process_group(pg) + return True, "Succeeded to destroy custom process group." + else: + return False, "The group to be destroyed does not exist." + except Exception as e: + message = f"Failed to destroy custom process group: {e}." + self.logger.error(message) + return False, message + + def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq): + """ + Update specific parameter in the model weights online + through `_model_update_group` process group. + + Args: + name: the name of the parameter to be updated. + dtype: the data type of the parameter to be updated. + shape: the shape of the parameter to be updated. + """ + + assert request.group_name in self._model_update_group, ( + f"Group {request.group_name} not in {list(self._model_update_group.keys())}. " + "Please call `init_weights_update_group` first." + ) + + try: + weights = [] + handles = [] + for name, dtype, shape in zip(request.names, request.dtypes, request.shapes): + target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) + weight = torch.empty(shape, dtype=target_dtype, device="cuda") + handles.append( + torch.distributed.broadcast( + weight, + src=0, + group=self._model_update_group[request.group_name], + async_op=True, + ) + ) + weights.append((name, weight)) + for handle in handles: + handle.wait() + + self.model.load_weights(weights) + return True, "Succeeded to update parameter online from distributed." + + except Exception as e: + error_msg = ( + f"Failed to update parameter online: {e}. " + f"The full weights of the ModelRunner are partially updated. " + f"Please discard the whole weights." + ) + self.logger.error(error_msg) + return False, error_msg + + def _update_weights_from_flattened_bucket( + self, + flattened_tensor_bucket_dict, + ): + """Handle flattened bucket format for weight updates""" + flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"] + metadata = flattened_tensor_bucket_dict["metadata"] + + # Convert metadata dict to our format + converted_metadata = [] + for meta in metadata: + converted_meta = FlattenedTensorMetadata( + name=meta.name, + shape=meta.shape, + dtype=meta.dtype, + start_idx=meta.start_idx, + end_idx=meta.end_idx, + numel=meta.numel, + ) + converted_metadata.append(converted_meta) + + # Create bucket and reconstruct tensors + bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata) + reconstructed_tensors = bucket.reconstruct_tensors() + + named_tensors = {name: tensor for name, tensor in reconstructed_tensors} + + # Load the reconstructed tensors using the standard method + self.model.load_weights(named_tensors) + + return True, "Succeeded to update parameter online from flattened bucket tensor." + + def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq): + try: + monkey_patch_torch_reductions() + if request.load_format == "flattened_bucket": + # Handle flattened bucket format + serialized_named_tensors = MultiprocessingSerializer.deserialize( + request.serialized_named_tensors[self.rank_in_dp] + ) + return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors) + + # We need to get device after patch otherwise the device would be wrong + self.device_module = torch.get_device_module("cuda") + infered_device = self.device_module.current_device() + + named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp]) + + def _unwrap_tensor(tensor, tp_rank, device): + if isinstance(tensor, LocalSerializedTensor): + tensor = tensor.get(tp_rank) + clone = tensor.to(device).clone() + del tensor # free the ipc tensor + return clone + + named_tensors = { + name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device) + for name, tensor in named_tensors + } + + self.model.load_weights(named_tensors) + + return True, "Succeeded to update parameter online from tensor." + + except Exception as e: + message = f"Failed to update parameter online from tensor. Reason: {e}." + self.logger.error(message) + + return False, message + def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor): """ 这个函数会把next token id和logprobs保存到pinned memory中 @@ -617,7 +818,7 @@ def _get_classed_reqs( paused_reqs.append(req_obj) continue - if req_obj.infer_aborted or req_obj.finish_status.is_finished(): + if req_obj.finish_status.is_finished(): if support_overlap: # 延迟处理 req_obj.filter_mark = True @@ -827,6 +1028,18 @@ def _sample_and_scatter_token( ) return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu + def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + """Scatter captured routing data from capture buffer to KV-indexed GPU buffer. + + Must be called AFTER model.forward() completes. mem_indexes should be the + original (unpadded) tensor — either CPU or CUDA. + """ + if _routing_mgr.g_routing_capture_manager is not None and mem_indexes is not None: + if not mem_indexes.is_cuda: + mem_indexes = mem_indexes.cuda(non_blocking=True) + num_tokens = mem_indexes.shape[0] + _routing_mgr.g_routing_capture_manager.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index) + def _dp_all_gather_prefill_and_decode_req_num( self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq] ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 60045fab6c..e068c00c76 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -109,6 +109,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -140,6 +141,7 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + # 第四阶段 event_pack.notify_pre_post_handle() return @@ -152,6 +154,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -190,6 +193,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -244,6 +248,7 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] @@ -266,6 +271,7 @@ def decode_mtp( key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb620..ebc55b7ef4 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -40,8 +40,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq ) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) logits = model_output.logits batch_idx, run_reqs = self._diverse_copy( diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c83e8cd4a5..c9484dba6f 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -151,6 +151,7 @@ def prefill_normal( run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -198,6 +199,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -246,6 +248,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -319,6 +323,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -373,6 +379,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -438,6 +445,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + self._flush_routing_to_kv_buffer(model_input.mem_indexes) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -646,6 +654,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -747,8 +757,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0) + self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1) logits0 = model_output0.logits logits1 = model_output1.logits run_reqs = run_reqs0 + run_reqs1 @@ -788,6 +799,20 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ) all_next_token_ids.append(next_token_ids) + # Copy accepted buffer states back to buffer[0] for MTP + # Only copy when accept_len > 1 + mask = mtp_accept_len > 1 + if mask.sum() > 0: + actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]] + src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[ + actual_req_idxes, mtp_accept_len[mask] - 1 + ] + dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0] + if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"): + g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p( + src_buffer_indexes, dst_buffer_indexes + ) + verify_event = torch.cuda.Event() verify_event.record() diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index f3ad03662e..406c889b72 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,7 +1,8 @@ import torch from typing import List, Tuple -from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty -from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args @@ -15,7 +16,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_top_ks, b_length_penalty_param, b_mask_eos_reqs, + invalid_token_ids, + cu_invalid_token_num, is_all_greedy, + has_invalid_token_ids, skip_top_k, skip_top_p, exist_req_use_random_seed, @@ -63,6 +67,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): eos_ids=eos_ids, sampling_params_manager=sampling_params_manager, ) + + if has_invalid_token_ids: + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + logits.div_(b_temperatures.view((-1, 1))) probs = torch.softmax(logits, dim=-1) @@ -152,6 +164,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_p = True exist_req_use_random_seed = False + # invalid token ids + invalid_token_ids: List[int] = [] + has_invalid_token_ids = False + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param shm_param = sample_param.shm_param @@ -173,6 +191,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]): if req_obj.generator is not None: exist_req_use_random_seed = True req_idxes.append(req_obj.req_idx) + invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) + cu_invalid_token_num.append(invalid_token_num_start) + if len(req_obj.sampling_param.invalid_token_ids) > 0: + has_invalid_token_ids = True + invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) @@ -183,6 +206,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): ) mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) + if has_invalid_token_ids: + invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True) + cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True) + return ( req_idxes_cpu.cuda(non_blocking=True), temperatures_cpu.cuda(non_blocking=True), @@ -190,7 +217,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks_cpu.cuda(non_blocking=True), length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), + invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, + cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, is_all_greedy, + has_invalid_token_ids, skip_top_k, skip_top_p, exist_req_use_random_seed, diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 408b173371..2060753ae6 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -37,6 +37,8 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.torch_memory_saver_utils import MemoryTag +from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp logger = init_logger(__name__) @@ -128,6 +130,35 @@ def exposed_init_model(self, kvargs): def exposed_get_max_total_token_num(self): return self.backend.get_max_total_token_num() + def release_memory_occupation(self, tags: List[MemoryTag]): + try: + self.backend.release_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"release memory occupation failed: {str(e)}") + return False + + def resume_memory_occupation(self, tags: List[MemoryTag]): + try: + self.backend.resume_memory_occupation(tags) + return True + except BaseException as e: + logger.exception(f"resume memory occupation failed: {str(e)}") + return False + + def exposed_forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + try: + req = obtain(req) + if self.backend is None or not hasattr(self.backend, req.func_name): + raise ValueError(f"Backend does not support function {req.func_name}") + success, ret = getattr(self.backend, req.func_name)(req.func_args) + return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret) + except BaseException as e: + logger.exception(f"forward to model backend failed: {str(e)}") + return GeneralModelToHttpRpcRsp( + success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name + ) + class ModelRpcClient: def __init__(self, conn): @@ -151,6 +182,7 @@ async def _func(*args, **kwargs): self._init_model = async_wrap(self.conn.root.init_model) self._get_max_total_token_num = async_wrap(self.conn.root.get_max_total_token_num) + self._forward_to_model = async_wrap(self.conn.root.forward_to_model) return async def init_model(self, kvargs): @@ -162,6 +194,10 @@ async def get_max_total_token_num(self): ans = self._get_max_total_token_num() return obtain(await ans) + async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp: + ans = self._forward_to_model(req) + return obtain(await ans) + def _init_env( args, @@ -222,7 +258,11 @@ async def start_model_process( success_event, ), ) - proc.start() + from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper + + torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver) + with torch_memory_saver.configure_subprocess(): + proc.start() # Use asyncio.to_thread to make the blocking wait non-blocking await asyncio.to_thread(success_event.wait, timeout=40) diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index 73113a59b8..e1e2479c86 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -33,6 +33,17 @@ def free_aborted_req_cpu_cache_pages(self, req: Req): req.cpu_cache_match_page_indexes.clear() self.router.cpu_cache_client.lock.release() + def free_aborted_req(self, req: Req): + # 为了让http server 能正常返回请求,还没有开始推理的请求,直接设置结束,返回空字符串 + input_len = req.input_len + req.link_prompt_ids_shm_array() + req.link_logprobs_shm_array() + req.candetoken_out_len = 1 + req.finish_token_index = input_len + req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0] + req.shm_logprobs.arr[input_len] = 0 + req.finish_status.set_status(FinishStatus.FINISHED_ABORTED) + def extend(self, req_group: List[Req]): for req in req_group: req.sample_params.suggested_dp_index = self.dp_index diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 884b5930b0..24f017d95c 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -84,8 +84,8 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: - # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. + if req.is_aborted and not self.router.is_multinode_tp: + # 由于管理的复杂性,只有没有被调度运行过的单节点请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 abort_req_list.append(req) @@ -104,6 +104,7 @@ def generate_new_batch(self, current_batch: Batch): req: Req = req logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py index 3b831c92a6..0b17bbd1c0 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py @@ -75,7 +75,7 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list for req in waiting_queue: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏 aborted_count += 1 @@ -95,6 +95,7 @@ def generate_new_batch(self, current_batch: Batch): req: Req = req logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py index 4c2ebf7c00..1bfb8fc59e 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py @@ -41,7 +41,7 @@ def generate_new_batch(self, current_batch: Batch): abort_req_list = [] aborted_count = 0 for req in self.waiting_req_list: - if req.is_aborted: + if req.is_aborted and not self.router.is_multinode_tp: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 aborted_count += 1 @@ -58,6 +58,7 @@ def generate_new_batch(self, current_batch: Batch): req: Req = req logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}") self.free_aborted_req_cpu_cache_pages(req) + self.free_aborted_req(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index a73823b8b7..e5f731df5f 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -27,6 +27,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.reqs_waiting_for_dp_index: List[List[Req]] = [] return + def free_aborted_req(self, req: Req): + dp_index = req.sample_params.suggested_dp_index + assert dp_index >= 0 and dp_index < self.dp_size_in_node + self.inner_queues[dp_index].free_aborted_req(req) + return + def get_dp_queue(self, dp_index: int): assert dp_index < self.dp_size_in_node, "dp index out of range" return self.inner_queues[dp_index] diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 1dffdaf681..6b013c4c63 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -10,7 +10,6 @@ import threading import collections from typing import List -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -21,6 +20,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.io_struct import BaseReq, GenerateReqIndex from rpyc.utils.classic import obtain @@ -91,7 +91,7 @@ async def wait_to_model_ready(self): await asyncio.gather(*init_model_ret) return - def get_need_infer_images(self, group_req_indexes: GroupReqIndexes) -> List[ImageItem]: + def get_need_infer_images(self, group_req_indexes: GenerateReqIndex) -> List[ImageItem]: shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) is_aborted = shm_req.is_aborted disable_prompt_cache = shm_req.sample_params.disable_prompt_cache @@ -121,7 +121,7 @@ def get_need_infer_images(self, group_req_indexes: GroupReqIndexes) -> List[Imag return images_need_infer - async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + async def handle_group_indexes(self, group_req_indexes: GenerateReqIndex): images_need_infer = self.get_need_infer_images(group_req_indexes) if len(images_need_infer) == 0: @@ -174,13 +174,16 @@ async def infer_images(self, dp_index: int, images, events): async def loop_for_netio_req(self): try: while True: - recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) - if isinstance(recv_req, GroupReqIndexes): + recv_req: BaseReq = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) + if isinstance(recv_req, GenerateReqIndex): logger.info( f"visual recv req id {recv_req.group_req_id} " f"img count {len(recv_req.multimodal_params.images)}" ) asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req)) + elif isinstance(recv_req, BaseReq): + # RL 等控制类 BaseReq 透传给下一模块,最终由 router 处理 + self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL) else: assert False, f"Error Req Inf {recv_req}" except Exception as e: diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 5b9705ed0e..87aae86e4f 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -80,12 +80,15 @@ def init_vision_distributed_env(kvargs): device_id = kvargs["device_id"] set_current_device_id(device_id) torch.cuda.set_device(device_id) + # 不要在init_process_group时,显示的传入device_id + # 这会触发torch的device-bound split优化,会默认后面想加入新进程组的rank + # 都已经存在于默认组,这样RL更新weight的init_group时,外部想加入的组,在执行 + # 通信原语时例如all_reduce,会永远等不到LightLLM默认组里的回复,从而导致错误结果。 dist.init_process_group( "nccl", init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', rank=kvargs["tp_rank_id"], world_size=tp_world_size, - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = torch.zeros([1]).to(f"cuda:{device_id}") @@ -150,7 +153,6 @@ def init_distributed_env(kvargs): init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}', rank=kvargs["rank_id"], world_size=kvargs["world_size"], - device_id=torch.device(f"cuda:{device_id}"), ) # warmup nccl communicator _a = torch.zeros([1]).to(f"cuda:{device_id}") @@ -316,3 +318,71 @@ def _init_nccl_env(): assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}" return + + +# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675 +def init_custom_process_group( + backend=None, + init_method=None, + timeout=None, + world_size=-1, + rank=-1, + store=None, + group_name=None, + pg_options=None, + device_id=None, +): + from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, + ) + + assert (store is None) or (init_method is None), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 + # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 + # We need to determine the appropriate parameter name based on PyTorch version + pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + device_id=device_id, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index b87096d945..c3a466191d 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -1,45 +1,92 @@ import socket import subprocess import ipaddress -import random +import os from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) +DEFAULT_BASE_PORT = 10000 +PORTS_PER_INSTANCE = 1000 +MAX_INSTANCE_ID = 7 -def alloc_can_use_network_port(num=3, used_ports=None, from_port_num=10000): - port_list = [] - for port in range(from_port_num, 65536): + +def _is_port_available(port: int) -> bool: + """Check if a port is available by attempting to bind it.""" + try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) - if result != 0 and port not in used_ports: - port_list.append(port) - if len(port_list) > num * 30: - break + s.bind(("", port)) + return True + except OSError: + return False - if len(port_list) < num: - return None - random.shuffle(port_list) - return port_list[0:num] +def alloc_can_use_network_port(num=3, instance_id=0, used_ports=None): + """ + Allocate available network ports within an instance-specific range. + + Each instance gets a dedicated 1000-port range starting from BASE_PORT + (default 10000, override via LIGHTLLM_BASE_PORT env var). + Instance 0: 10000-10999, Instance 1: 11000-11999, etc. + """ + if instance_id < 0 or instance_id > MAX_INSTANCE_ID: + raise ValueError(f"instance_id must be in range [0, {MAX_INSTANCE_ID}], got {instance_id}") + + base_port = int(os.environ.get("LIGHTLLM_BASE_PORT", DEFAULT_BASE_PORT)) + range_start = base_port + instance_id * PORTS_PER_INSTANCE + range_end = range_start + PORTS_PER_INSTANCE + used_set = set(used_ports) if used_ports else set() + + port_list = [] + for port in range(range_start, range_end): + if len(port_list) >= num: + break + if port in used_set: + continue + if _is_port_available(port): + port_list.append(port) + used_set.add(port) + + if len(port_list) >= num: + logger.info( + f"Instance {instance_id}: allocated {len(port_list)} ports in [{range_start}, {range_end}): {port_list}" + ) + return port_list + + raise RuntimeError( + f"Failed to allocate {num} ports for instance {instance_id} in range [{range_start}, {range_end}). " + f"Only found {len(port_list)} available. Try a different instance_id or set LIGHTLLM_BASE_PORT." + ) def alloc_can_use_port(min_port, max_port): port_list = [] for port in range(min_port, max_port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: port_list.append(port) + except Exception: + continue return port_list def find_available_port(start_port, end_port): for port in range(start_port, end_port + 1): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - result = sock.connect_ex(("localhost", port)) + try: + test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + result = test_socket.connect_ex(("localhost", port)) + test_socket.close() + if result != 0: return port + except Exception: + continue return None diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py new file mode 100644 index 0000000000..9f51edeb64 --- /dev/null +++ b/lightllm/utils/patch_torch.py @@ -0,0 +1,63 @@ +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py +from typing import Callable, Union + +import torch +from packaging import version +from torch.multiprocessing import reductions + + +def monkey_patch_torch_reductions(): + """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" + + # Currently, NPU does not support UUID. This has been temporarily commented out, + # with support expected in the fourth quarter. + # if _is_npu: + # return + + if hasattr(reductions, "_reduce_tensor_original"): + return + + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + + reductions.init_reductions() + + +# The signature has not been changed for years, and we will not need this when the next version is released, +# so it looks safe to use a constant. +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise Exception("Invalid device_uuid=" + device_maybe_uuid) + + raise Exception(f"Unknown type: {device_maybe_uuid=}") + + +def _modify_tuple(t, index: int, modifier: Callable): + return *t[:index], modifier(t[index]), *t[index + 1 :] diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py new file mode 100644 index 0000000000..d8180aeb0c --- /dev/null +++ b/lightllm/utils/serializer.py @@ -0,0 +1,131 @@ +# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py + +import base64 +import pickle +import io +from dataclasses import dataclass +from multiprocessing.reduction import ForkingPickler +from typing import List + + +class MultiprocessingSerializer: + @staticmethod + def serialize(obj, output_str: bool = False): + """ + Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = base64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """ + Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. + """ + if isinstance(data, str): + # Decode base64 string to bytes + data = base64.b64decode(data, validate=True) + + return SafeUnpickler(io.BytesIO(data)).load() + + +class SafeUnpickler(pickle.Unpickler): + ALLOWED_MODULE_PREFIXES = { + # --- Python types --- + "builtins.", + "collections.", + "copyreg.", + "functools.", + "itertools.", + "operator.", + "types.", + "weakref.", + # --- PyTorch types --- + "torch.", + "torch._tensor.", + "torch.storage.", + "torch.nn.parameter.", + "torch.autograd.function.", + # --- torch distributed --- + "torch.distributed.", + "torch.distributed._shard.", + "torch.distributed._composable.", + "torch._C._distributed_c10d.", + "torch._C._distributed_fsdp.", + "torch.distributed.optim.", + # --- multiprocessing --- + "multiprocessing.resource_sharer.", + "multiprocessing.reduction.", + "pickletools.", + # --- PEFT / LoRA --- + "peft.", + "transformers.", + "huggingface_hub.", + # --- SGLang & Unitest --- + "sglang.srt.weight_sync.tensor_bucket.", + "sglang.srt.model_executor.model_runner.", + "sglang.srt.layers.", + "sglang.srt.utils.", + # --- LightLLM --- + "lightllm.utils.", + } + + DENY_CLASSES = { + ("builtins", "eval"), + ("builtins", "exec"), + ("builtins", "compile"), + ("os", "system"), + ("subprocess", "Popen"), + ("subprocess", "run"), + ("codecs", "decode"), + ("types", "CodeType"), + ("types", "FunctionType"), + } + + def find_class(self, module, name): + # Block deterministic attacks + if (module, name) in self.DENY_CLASSES: + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" + ) + # Allowlist of safe-to-load modules. + if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES): + return super().find_class(module, name) + + # Block everything else. (Potential attack surface) + raise RuntimeError( + f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164" + ) + + +@dataclass +class LocalSerializedTensor: + """torch.Tensor that gets serialized by MultiprocessingSerializer + (which only serializes a pointer and not the data). + The i-th element in the list corresponds to i-th rank's GPU.""" + + values: List[bytes] + + def get(self, rank: int): + return MultiprocessingSerializer.deserialize(self.values[rank]) diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py new file mode 100644 index 0000000000..a9d7a367dd --- /dev/null +++ b/lightllm/utils/tensor_bucket.py @@ -0,0 +1,104 @@ +# copy from +# https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/ +# srt/weight_sync/tensor_bucket.py +from dataclasses import dataclass +from typing import List, Tuple + +import torch + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + # This field is solely for users of to check whether the class supports this feature + supports_multi_dtypes = True + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. + Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + + # Collect metadata and flatten tensors + current_idx = 0 + flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flattened = tensor.flatten().view(torch.uint8) + flattened_tensors[i] = flattened + + # Store metadata + + numel = flattened.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + + # Concatenate all flattened tensors + self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata") + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. + """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py new file mode 100644 index 0000000000..c1184ef30c --- /dev/null +++ b/lightllm/utils/torch_memory_saver_utils.py @@ -0,0 +1,92 @@ +import torch +from contextlib import contextmanager +from enum import Enum +from lightllm.utils.log_utils import init_logger + +try: + from torch_memory_saver import ( + torch_memory_saver, + configure_subprocess, + ) + + HAS_TORCH_MEMORY_SAVER = True + +except ImportError: + HAS_TORCH_MEMORY_SAVER = False + pass + +logger = init_logger(__name__) + + +class MemoryTag(Enum): + KV_CACHE = "kv_cache" + WEIGHT = "weights" + GRAPH = "graph" + + def is_kv_cache(self): + return self == MemoryTag.KV_CACHE + + def is_weight(self): + return self == MemoryTag.WEIGHT + + def is_graph(self): + return self == MemoryTag.GRAPH + + def __str__(self): + return self.value + + +class TorchMemorySaverWrapper: + def __new__(cls, enable_torch_memory_saver: bool = False): + if enable_torch_memory_saver: + assert ( + HAS_TORCH_MEMORY_SAVER + ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`." + return _TorchMemorySaver() + else: + return _TorchMemorySaverFake() + + +class _TorchMemorySaver: + def configure_subprocess(self): + return configure_subprocess() + + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup) + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value) + + def disable(self): + return torch_memory_saver.disable() + + def pause(self, tag: MemoryTag): + return torch_memory_saver.pause(tag=tag.value) + + def resume(self, tag: MemoryTag): + return torch_memory_saver.resume(tag=tag.value) + + +class _TorchMemorySaverFake: + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self, tag: MemoryTag, enable_cpu_backup: bool = False): + yield + + def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs): + return torch.cuda.graph(graph_obj, **kwargs) + + @contextmanager + def disable(self): + yield + + def pause(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, pause is not supported.") + return + + def resume(self, tag: MemoryTag): + logger.warning("torch_memory_saver is not enabled, resume is not supported.") + return diff --git a/requirements.txt b/requirements.txt index d37ae05690..2ede1b24c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,6 +94,7 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +torch_memory_saver==0.0.9 xformers==0.0.33.post2 redis==7.3.0 litellm>=1.52.0,<1.85 diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py new file mode 100644 index 0000000000..85c4e44ef9 --- /dev/null +++ b/test/test_api/test_r3.py @@ -0,0 +1,92 @@ +import sys +import argparse +import requests +import base64 +import numpy as np + + +def test_routing_export(url: str = "http://localhost:8000"): + print(f"Testing routing export at {url}") + print("-" * 50) + + try: + response = requests.post( + f"{url}/generate", + json={ + "inputs": "What is the capital of France? What is the capital of France?", + "parameters": { + "max_new_tokens": 50, + # "return_routed_experts": True, + # "repetition_penalty": 1.0, + }, + }, + timeout=60, + ) + except requests.exceptions.ConnectionError: + print(f"ERROR: Cannot connect to server at {url}") + print("Make sure the LightLLM server is running with --enable_return_routed_experts") + return False + except requests.exceptions.Timeout: + print("ERROR: Request timed out") + return False + + print(f"Status: {response.status_code}") + + if response.status_code != 200: + print(f"ERROR: Request failed with status {response.status_code}") + print(f"Response: {response.text}") + return False + + res = response.json() + print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") + + if "routed_experts" not in res or not res["routed_experts"]: + print("\nWARNING: No routed_experts in response.") + print("This could mean:") + print(" - The model is not a MoE model") + print(" - The server was not started with --enable_return_routed_experts") + print(" - The routing capture manager was not initialized") + return False + + routing_info = res["routed_experts"] + shape = routing_info["shape"] + dtype_str = routing_info["dtype"] + dtype = np.dtype(dtype_str) + data = base64.b64decode(routing_info["data"]) + routing_array = np.frombuffer(data, dtype=dtype).reshape(shape) + + print(f"\n{'=' * 50}") + print("ROUTING CAPTURE SUCCESS!") + print(f"{'=' * 50}") + print(f"Shape: {shape}") + print(f"Dtype: {dtype}") + print(f"Num tokens: {shape[0]}") + print(f"Num MoE layers: {shape[1]}") + print(f"Top-K: {shape[2]}") + + # Compute payload size savings + int32_size = np.prod(shape) * 4 + actual_size = len(data) + savings = (1 - actual_size / int32_size) * 100 + print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)") + + print(f"\nSample routing (first layer, first 5 tokens):") + num_tokens_to_show = shape[0] + for i in range(num_tokens_to_show): + print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}") + + if np.all(routing_array == 0): + print("\nWARNING: All routing data is zeros. Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 0000000000..dcc010b372 --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,219 @@ +import torch +import numpy as np + + +class TestRoutingCaptureManager: + def test_capture_and_extract_basic(self): + """Test the core pipeline: capture → flush_to_kv_buffer → extract_from_gpu.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=64, + ) + + # Simulate a batch of 10 tokens at KV-cache positions [100..109] + mem_indexes = torch.arange(100, 110, device="cuda") + + # Capture routing for each MoE layer (writes to capture buffer) + for layer_idx in range(4): + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=layer_idx, topk_ids=topk_ids, microbatch_index=0) + + # Flush from capture buffer to KV-indexed gpu_kv_buffer + manager.flush_to_kv_buffer(mem_indexes, num_tokens=10, microbatch_index=0) + + # Extract for those same KV-cache positions + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (4, 10, 8) + assert result.dtype == np.int8 + + def test_capture_writes_to_correct_kv_positions(self): + """Verify that captured data lands in the right KV-cache positions after flush.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Use non-contiguous mem_indexes to simulate real KV-cache + mem_indexes = torch.tensor([10, 50, 200], device="cuda") + + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + topk_ids_layer1 = topk_ids + 20 + manager.capture(moe_layer_index=1, topk_ids=topk_ids_layer1, microbatch_index=0) + + # Flush to KV positions + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + + # Extract and verify + result = manager.extract_from_gpu(mem_indexes) + assert result.shape == (2, 3, 4) + np.testing.assert_array_equal(result[0], topk_ids.cpu().numpy()) + np.testing.assert_array_equal(result[1], topk_ids_layer1.cpu().numpy()) + + def test_microbatch_isolation(self): + """Two microbatches writing to different KV positions don't interfere.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=32, + kv_cache_size=256, + max_capture_tokens=16, + ) + + # Microbatch 0: positions [10, 11] + mem0 = torch.tensor([10, 11], device="cuda") + ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Microbatch 1: positions [20, 21] + mem1 = torch.tensor([20, 21], device="cuda") + ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Flush each microbatch to different KV positions + manager.flush_to_kv_buffer(mem0, num_tokens=2, microbatch_index=0) + manager.flush_to_kv_buffer(mem1, num_tokens=2, microbatch_index=1) + + # Extract microbatch 0 + result0 = manager.extract_from_gpu(mem0) + assert result0.shape == (1, 2, 4) + assert result0[0, 0, 0] == 1 + + # Extract microbatch 1 + result1 = manager.extract_from_gpu(mem1) + assert result1[0, 0, 0] == 2 + + def test_dtype_selection_int8(self): + """Models with ≤127 experts use int8.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int8 + assert manager.np_dtype == np.int8 + assert manager.dtype_id == 1 + + def test_dtype_selection_int16(self): + """Models with >127 experts use int16.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + kv_cache_size=128, + max_capture_tokens=16, + ) + assert manager.dtype == torch.int16 + assert manager.np_dtype == np.int16 + assert manager.dtype_id == 2 + + def test_extract_preserves_values(self): + """Extracted values exactly match what was captured, no dtype truncation.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=4, + num_experts=64, + kv_cache_size=64, + max_capture_tokens=16, + ) + + mem_indexes = torch.tensor([0, 1, 2], device="cuda") + + topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 127, 3]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush then extract + manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0) + result = manager.extract_from_gpu(mem_indexes) + expected = topk_ids.cpu().numpy().astype(np.int8) + np.testing.assert_array_equal(result[0], expected) + + def test_gpu_kv_buffer_shape(self): + """Buffer shape is (num_moe_layers, kv_cache_size, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # 127 experts fits in int8 (max value 127) + manager = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=127, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager.gpu_kv_buffer.shape == (48, 2048, 8) + assert manager.gpu_kv_buffer.dtype == torch.int8 + assert manager.gpu_kv_buffer.device.type == "cuda" + + # 128 experts requires int16 + manager2 = RoutingCaptureManager( + num_moe_layers=48, + topk=8, + num_experts=128, + kv_cache_size=2048, + max_capture_tokens=256, + ) + assert manager2.gpu_kv_buffer.dtype == torch.int16 + + def test_partial_token_capture(self): + """capture() only writes num_tokens rows to the buffer.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=32, + kv_cache_size=128, + max_capture_tokens=16, + ) + + # Capture only 3 tokens, flush to 5 KV positions (first 3 get data) + mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda") + + topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda") # only 3 tokens + manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0) + + # Flush only the 3 captured tokens + manager.flush_to_kv_buffer(mem_indexes[:3], num_tokens=3, microbatch_index=0) + + # Positions 10-12 should have data, 13-14 should be zeros (from init) + result_written = manager.extract_from_gpu(mem_indexes[:3]) + np.testing.assert_array_equal(result_written[0], topk_ids.cpu().numpy().astype(np.int8)) + + result_unwritten = manager.extract_from_gpu(mem_indexes[3:]) + np.testing.assert_array_equal(result_unwritten[0], np.zeros((2, 2), dtype=np.int8)) + + def test_capture_buffer_shape(self): + """Capture buffer has correct shape (max_tokens, num_moe_layers, topk).""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + kv_cache_size=1024, + max_capture_tokens=256, + ) + assert manager._capture_buffer[0].shape == (256, 4, 8) + assert manager._capture_buffer[1].shape == (256, 4, 8) + assert manager._capture_buffer[0].dtype == torch.int8 diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py new file mode 100644 index 0000000000..3b2f159f62 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import ( + apply_invalid_token_ids, +) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_apply_invalid_token_ids(dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Triton kernels.") + + batch_size = 4 + vocab_size = 32 + logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype) + expected = logits.clone() + + invalid_token_ids_per_batch = [ + [1, 3, 5], + [], + [0, 2, 31], + [7], + ] + + flat_ids = [] + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for ids in invalid_token_ids_per_batch: + flat_ids.extend(ids) + invalid_token_num_start += len(ids) + cu_invalid_token_num.append(invalid_token_num_start) + + invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32) + cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32) + + for batch_idx, ids in enumerate(invalid_token_ids_per_batch): + if ids: + expected[batch_idx, ids] = float("-inf") + + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + assert torch.equal(logits, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py index 605433e9d8..dfeda0b6f7 100644 --- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py +++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py @@ -230,5 +230,32 @@ def test_case9(): assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64)) +def test_case10(): + """ + 测试场景:测试 flush_cache 函数 + """ + print("\nTest Case 10: Testing flush_cache function\n") + tree = RadixCache("unique_name", 100, 0) + tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64)) + tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)) + tree_node, size, values = tree.match_prefix( + torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node is not None + assert size == 3 + tree.flush_cache() + tree_node, size, values = tree.match_prefix( + torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node is None + assert size == 0 + assert tree.get_tree_total_tokens_num() == 0 + assert tree.get_refed_tokens_num() == 0 + assert len(tree.root_node.children) == 0 + assert tree.root_node.token_id_key.numel() == 0 + assert tree.root_node.token_mem_index_value.numel() == 0 + assert tree.root_node.ref_counter == 1 + + if __name__ == "__main__": pytest.main()