diff --git a/.gitignore b/.gitignore
index 63408699f4..3fb49db8b1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ dist
.vscode
tmp/
requirements-musa.txt
+CLAUDE.md
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 89d90f3c12..219343e2c6 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -7,7 +7,7 @@
import torch
import torch.nn.functional as F
import triton
-from typing import final, List
+from typing import final, List, Optional
from tqdm import tqdm
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -33,6 +33,10 @@
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
+from lightllm.utils.torch_memory_saver_utils import (
+ TorchMemorySaverWrapper,
+ MemoryTag,
+)
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend
@@ -91,6 +95,7 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
+ self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
@@ -104,19 +109,21 @@ def __init__(self, kvargs):
self._verify_params()
self._init_quant()
- self._init_weights()
- self._init_req_manager()
- self._init_mem_manager()
+ with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup):
+ self._init_weights()
+ with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
+ self._init_req_manager()
+ self._init_mem_manager()
+ self._init_kv_move_buffer()
+
# 因为类似 qwen3.5 的linear 架构的模型,其 req_manager 会存储运行时使用的大量 linear state
# 这可能会占用大量的显存,所以,req_manger 中保存的 mem_manger 是mem manager 初始化后再赋值
self.req_manager.mem_manager = self.mem_manager
-
- self._init_kv_move_buffer()
self._check_mem_size()
self._init_infer_layer()
self._init_some_value()
self._init_custom()
- self._load_hf_weights()
+ self.load_weights(self.weight_dict)
# wait必须在init cudagraph 之前,避免错误捕获
self._wait_other_modules_ready()
@@ -181,17 +188,15 @@ def _init_weights(self, start_layer_index=0):
]
return
- def _load_hf_weights(self):
+ def load_weights(self, weight_dict: dict):
+ assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None"
load_hf_weights(
self.data_type,
- weight_dir=self.weight_dir_,
+ self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
- weight_dict=self.weight_dict,
+ weight_dict=weight_dict,
)
- self.pre_post_weight.verify_load()
- [weight.verify_load() for weight in self.trans_layers_weight]
- return
def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
@@ -999,6 +1004,7 @@ def _check_max_len_infer(self):
)
logger.error(exception_str)
raise Exception(exception_str)
+ torch.cuda.empty_cache()
return
def autotune_layers(self):
@@ -1133,6 +1139,9 @@ def _init_padded_req(self):
del b_seq_len
del b_ready_cache_len
del model_output
+ del b_mtp_index
+ del b_prefill_start_loc
+ del b_q_seq_len
torch.cuda.empty_cache()
return
@@ -1153,3 +1162,72 @@ def _gen_special_model_input(self, token_num: int):
special_model_input["mtp_draft_input_hiddens"] = None
return special_model_input
+
+ def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+ torch.cuda.synchronize()
+ if tags is None:
+ self.release_all()
+ return
+ if MemoryTag.WEIGHT in tags:
+ self.release_weight()
+ if MemoryTag.KV_CACHE in tags:
+ self.release_kv_cache()
+ if MemoryTag.GRAPH in tags:
+ self.release_graph()
+ return
+
+ def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+ if tags is None:
+ self.resume_all()
+ return
+ if MemoryTag.WEIGHT in tags:
+ self.resume_weight()
+ if MemoryTag.KV_CACHE in tags:
+ self.resume_kv_cache()
+ if MemoryTag.GRAPH in tags:
+ self.resume_graph()
+ return
+
+ def release_weight(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def release_kv_cache(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def release_graph(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def release_all(self):
+ self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
+ self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
+ self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def resume_weight(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
+
+ def resume_kv_cache(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
+
+ def resume_graph(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
+
+ def resume_all(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
+ self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
+ self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 2a099ffc75..348d9216b7 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -8,6 +8,10 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.utils.torch_memory_saver_utils import (
+ TorchMemorySaverWrapper,
+ MemoryTag,
+)
from .infer_struct import InferStateInfo
@@ -26,6 +30,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int =
self.max_batch_size = max_batch_size
self.graph_max_len_in_batch = max_len_in_batch
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+ self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -96,7 +101,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
delattr(infer_state, param_name)
with lightllm_capture_graph(dist_group):
- with torch.cuda.graph(graph_obj, pool=self.mempool):
+ with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
graph_obj.replay()
@@ -134,7 +139,7 @@ def _capture_decode_overlap(
with lightllm_capture_graph(dist_group1):
with lightllm_capture_graph(dist_group):
- with torch.cuda.graph(graph_obj, pool=self.mempool):
+ with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
graph_obj,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
index 8f54e14a72..6ca48299f0 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
@@ -33,6 +33,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
+ moe_layer_index: int = 0,
) -> None:
super().__init__(data_type=data_type)
self.w1_weight_name = gate_proj_name
@@ -50,6 +51,7 @@ def __init__(
self.enable_ep_moe = get_env_start_args().enable_ep_moe
self.n_routed_experts = n_routed_experts
self.num_fused_shared_experts = num_fused_shared_experts
+ self.moe_layer_index = moe_layer_index
self._init_config(network_config)
self._init_redundancy_expert_params()
self._init_parallel_params()
@@ -130,6 +132,7 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ microbatch_index: int = 0,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
return self.fuse_moe_impl(
@@ -145,6 +148,8 @@ def experts(
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
+ moe_layer_index=self.moe_layer_index,
+ microbatch_index=microbatch_index,
)
def low_latency_dispatch(
@@ -295,6 +300,7 @@ def _create_weight(self):
device_id=self.device_id_,
num_experts=self.local_n_routed_experts,
)
+ self.w1, self.w3 = w13_param_list
self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0])
self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1])
self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2)
@@ -307,7 +313,8 @@ def _get_expert_weight_list(self, weight_pack: WeightPack):
return weight_list
def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]):
-
+ # for merged weights
+ self._load_merge_weight(weights)
# Load each expert with TP slicing
for expert_idx, local_expert_idx in expert_idx_to_local_idx.items():
with self.lock:
@@ -332,6 +339,7 @@ def _load_expert(
w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}"
w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}"
w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}"
+
row_slice_func = self.row_slicer._slice_weight
col_slice_func = self.col_slicer._slice_weight
if w1_weight in weights:
@@ -341,6 +349,19 @@ def _load_expert(
if w2_weight in weights:
self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx])
+ def _load_merge_weight(self, weights: Dict[str, torch.Tensor]):
+ w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}"
+ w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}"
+ w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}"
+ row_slice_func = self.row_slicer._slice_weight
+ col_slice_func = self.col_slicer._slice_weight
+ if w1_merge_weight in weights:
+ self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1)
+ if w2_merge_weight in weights:
+ self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2)
+ if w3_merge_weight in weights:
+ self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3)
+
def _load_expert_scale(
self,
expert_idx: int,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
index 6ed0cef0b4..4ca1605be4 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py
@@ -8,6 +8,7 @@
from lightllm.common.quantization import Quantcfg
from lightllm.common.quantization.quantize_method import QuantizationMethod
from lightllm.utils.log_utils import init_logger
+from lightllm.common.basemodel import routing_manager as _routing_mgr
logger = init_logger(__name__)
@@ -46,6 +47,7 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
+ moe_layer_index: int = 0,
) -> None:
network_config["norm_topk_prob"] = None
super().__init__(
@@ -62,6 +64,7 @@ def __init__(
num_fused_shared_experts=num_fused_shared_experts,
layer_num=layer_num,
network_config=network_config,
+ moe_layer_index=moe_layer_index,
)
self.hidden_size = network_config["hidden_size"]
@@ -144,10 +147,15 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ microbatch_index: int = 0,
):
topk_weights, topk_ids = self._router(router_logits, top_k)
+ # Rollout router replay
+ if _routing_mgr.g_routing_capture_manager is not None:
+ _routing_mgr.g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)
+
w1, w1_scale = self.w1
w2, w2_scale = self.w2
use_fp8_w8a8 = self.quant_method is not None
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
index 00587ac185..1c93cb13dc 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py
@@ -62,5 +62,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ moe_layer_index: Optional[int] = None,
+ microbatch_index: int = 0,
) -> torch.Tensor:
pass
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
index d6e923a115..90b525d275 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
@@ -3,6 +3,7 @@
from lightllm.common.quantization.no_quant import WeightPack
from lightllm.common.quantization.quantize_method import QuantizationMethod
from .base_impl import FuseMoeBaseImpl
+from lightllm.common.basemodel import routing_manager as _routing_mgr
class FuseMoeTriton(FuseMoeBaseImpl):
@@ -125,6 +126,8 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
+ moe_layer_index: Optional[int] = None,
+ microbatch_index: int = 0,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
@@ -137,6 +140,10 @@ def __call__(
num_expert_group=num_expert_group,
scoring_func=scoring_func,
)
+
+ if _routing_mgr.g_routing_capture_manager is not None and moe_layer_index is not None:
+ _routing_mgr.g_routing_capture_manager.capture(moe_layer_index, topk_ids, microbatch_index)
+
output = self._fused_experts(
input_tensor=input_tensor,
w13=w13,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
index ddbf98a866..15f050c14a 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
@@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
# 默认weight 的shape是 outxin,这也是目前最通用的约定。
-# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。
+# 这里约定 row-wise 沿着倒数第二维 (dim=-2) 切分,col-wise 沿着最后一维 (dim=-1) 切分。
class RowSliceMixin(SliceMixinTpl):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
super().__init__(tp_rank, tp_world_size, repeat_times)
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert (
- weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight.shape[0])
- return weight[start:end, :]
+ weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight.shape[-2])
+ return weight[..., start:end, :]
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
assert (
@@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
- weight_scale.shape[0] % self.tp_world_size_ == 0
- ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_scale.shape[0])
- return weight_scale[start:end]
+ weight_scale.shape[-2] % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_scale.shape[-2])
+ return weight_scale[..., start:end, :]
def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
assert (
- weight_zero_point.shape[0] % self.tp_world_size_ == 0
- ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_zero_point.shape[0])
- return weight_zero_point[start:end]
+ weight_zero_point.shape[-2] % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_zero_point.shape[-2])
+ return weight_zero_point[..., start:end, :]
class ColSliceMixin(SliceMixinTpl):
@@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert (
- weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight.shape[1])
- return weight[:, start:end]
+ weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight.shape[-1])
+ return weight[..., start:end]
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias / self.tp_world_size_ * self.repeat_times_
@@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
-            weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
+            weight_scale.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_scale.shape[1])
- return weight_scale[:, start:end]
+ ), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_scale.shape[-1])
+ return weight_scale[..., start:end]
def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
assert (
- weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
- ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
- start, end = self._get_slice_start_end(weight_zero_point.shape[1])
- return weight_zero_point[:, start:end]
+ weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
+ ), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
+ start, end = self._get_slice_start_end(weight_zero_point.shape[-1])
+ return weight_zero_point[..., start:end]
# awq 的量化权重是inxout存储格式,需要定制实现。
diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py
new file mode 100644
index 0000000000..77b611130f
--- /dev/null
+++ b/lightllm/common/basemodel/routing_manager.py
@@ -0,0 +1,191 @@
+import atexit
+import torch
+import numpy as np
+from multiprocessing import shared_memory
+from typing import Optional
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_current_rank_in_dp
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
+from lightllm.utils.envs_utils import get_unique_server_name
+
+logger = init_logger(__name__)
+
+
+def routing_dtype_id_to_np(dtype_id: int):
+ if dtype_id == 1:
+ return np.uint8
+ elif dtype_id == 2:
+ return np.int16
+ return np.int32
+
+
+def get_routing_config_shm() -> SharedArray:
+ service_name = get_unique_server_name()
+ return SharedArray(f"{service_name}_routing_config", shape=(4,), dtype=np.int32)
+
+
+class RoutingCaptureManager:
+ def __init__(
+ self,
+ num_moe_layers: int,
+ topk: int,
+ num_experts: int,
+ kv_cache_size: int,
+ max_capture_tokens: int,
+ ):
+ self.num_moe_layers = num_moe_layers
+ self.topk = topk
+ self.num_experts = num_experts
+ self.kv_cache_size = kv_cache_size
+
+ self.dtype = torch.uint8 if num_experts <= 255 else torch.int16
+ dtype_bytes = 1 if self.dtype == torch.uint8 else 2
+
+ # Shape: (kv_cache_size, num_moe_layers, topk) — on CPU to save GPU memory.
+ # Written after forward() via flush_to_routing_buffer(), read on request finish.
+ routing_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes
+ self.routing_buffer = torch.zeros(
+ (kv_cache_size, num_moe_layers, topk),
+ dtype=self.dtype,
+ device="cpu",
+ )
+
+ # Capture buffers: simple contiguous tensors written to during forward().
+ capture_buf_size = max_capture_tokens * num_moe_layers * topk * dtype_bytes
+ self._capture_buffer = [
+ torch.zeros((max_capture_tokens, num_moe_layers, topk), dtype=self.dtype, device="cuda") for _ in range(2)
+ ]
+
+ dtype_name = "uint8" if self.dtype == torch.uint8 else "int16"
+ logger.info(
+ f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, "
+ f"routing_buffer(cpu)={routing_buffer_size / 1024 / 1024:.2f}MB, "
+ f"capture_buffer={capture_buf_size / 1024 / 1024:.2f}MB x2, dtype={dtype_name}"
+ )
+
+ @property
+ def np_dtype(self):
+ return np.uint8 if self.dtype == torch.uint8 else np.int16
+
+ @property
+ def dtype_id(self) -> int:
+ return 1 if self.dtype == torch.uint8 else 2
+
+ def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
+ num_tokens = topk_ids.shape[0]
+ self._capture_buffer[microbatch_index][:num_tokens, moe_layer_index, :] = topk_ids.to(self.dtype)
+
+ def flush_to_routing_buffer(self, mem_indexes: torch.Tensor, num_tokens: int, microbatch_index: int = 0) -> None:
+ buf = self._capture_buffer[microbatch_index][:num_tokens] # (num_tokens, num_moe_layers, topk)
+ self.routing_buffer[mem_indexes[:num_tokens].cpu(), :, :] = buf.cpu()
+
+ def extract_routing_data(self, mem_indexes: torch.Tensor) -> np.ndarray:
+ cpu_indexes = mem_indexes.cpu() if mem_indexes.is_cuda else mem_indexes
+ return self.routing_buffer[cpu_indexes, :, :].numpy()
+
+
+g_routing_capture_manager: Optional[RoutingCaptureManager] = None
+
+
+def create_routing_capture_manager(
+ num_moe_layers: int,
+ topk: int,
+ num_experts: int,
+ kv_cache_size: int,
+ max_capture_tokens: int,
+) -> None:
+ global g_routing_capture_manager
+ assert g_routing_capture_manager is None, "RoutingCaptureManager already exists"
+ g_routing_capture_manager = RoutingCaptureManager(
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ num_experts=num_experts,
+ kv_cache_size=kv_cache_size,
+ max_capture_tokens=max_capture_tokens,
+ )
+
+
+def cleanup_routing_shm_pool() -> None:
+ """Unlink all pre-allocated routing SHM segments. Called at server shutdown."""
+ try:
+ from lightllm.utils.envs_utils import get_env_start_args
+
+ args = get_env_start_args()
+ except Exception:
+ return
+
+ service_name = get_unique_server_name()
+
+ for i in range(args.running_max_req_size):
+ name = f"{service_name}_shm_routing_{i}"
+ try:
+ shm = shared_memory.SharedMemory(name=name)
+ shm.close()
+ shm.unlink()
+ except Exception:
+ pass
+
+ config_name = f"{service_name}_routing_config"
+ try:
+ shm = shared_memory.SharedMemory(name=config_name)
+ shm.close()
+ shm.unlink()
+ except Exception:
+ pass
+
+
+def init_routing_capture(model, num_moe_layers: int) -> None:
+ dp_rank = get_current_rank_in_dp()
+ logger.info(f"init_routing_capture called: num_moe_layers={num_moe_layers}, dp_rank={dp_rank}")
+ if dp_rank != 0:
+ logger.info(f"Skipping routing capture initialization on dp_rank={dp_rank}")
+ return
+
+ if num_moe_layers == 0:
+ logger.warning(
+ "enable_return_routed_experts is set but no MoE layers found. Routing capture will not be enabled."
+ )
+ return
+
+ num_experts = model.config.get("n_routed_experts", model.config.get("num_experts", 0))
+ topk = model.config.get("num_experts_per_tok", 0)
+ assert num_experts > 0 and topk > 0
+
+ from lightllm.utils.envs_utils import get_env_start_args
+
+ args = get_env_start_args()
+
+ # Capture buffer must fit the max tokens in any single forward call.
+ # For prefill that's batch_max_tokens; for decode it's graph_max_batch_size.
+ batch_max_tokens = args.batch_max_tokens or args.max_req_total_len or 8192
+ max_capture_tokens = max(batch_max_tokens, args.graph_max_batch_size)
+
+ logger.info(
+ f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
+ f"topk={topk}, num_experts={num_experts}, max_capture_tokens={max_capture_tokens}"
+ )
+
+ create_routing_capture_manager(
+ num_moe_layers=num_moe_layers,
+ topk=topk,
+ num_experts=num_experts,
+ kv_cache_size=model.mem_manager.size + 1,
+ max_capture_tokens=max_capture_tokens,
+ )
+
+ mgr = g_routing_capture_manager
+ dtype_id = mgr.dtype_id
+
+ max_req_total_len = args.max_req_total_len
+
+ # Write config to cross-process SHM
+ shm = get_routing_config_shm()
+ shm.arr[0] = num_moe_layers
+ shm.arr[1] = topk
+ shm.arr[2] = dtype_id
+ shm.arr[3] = max_req_total_len
+ logger.info(
+ f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}, "
+ f"dtype_id={dtype_id}, max_tokens={max_req_total_len}"
+ )
+ atexit.register(cleanup_routing_shm_pool)
diff --git a/lightllm/common/basemodel/triton_kernel/post_process/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
new file mode 100644
index 0000000000..353affd8ed
--- /dev/null
+++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
@@ -0,0 +1,36 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_apply_invalid_token(
+ Logits,
+ invalid_token_ids,
+ cu_invalid_token_num,
+ stride_logit_b,
+):
+ cur_batch = tl.program_id(0)
+ start_index = tl.load(cu_invalid_token_num + cur_batch)
+ end_index = tl.load(cu_invalid_token_num + cur_batch + 1)
+ for i in range(start_index, end_index):
+ cur_invalid_token_id = tl.load(invalid_token_ids + i)
+ cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id
+ tl.store(cur_logit_ptr, float("-inf"))
+ return
+
+
+def apply_invalid_token_ids(
+ Logits: torch.Tensor,
+ invalid_token_ids: torch.Tensor,
+ cu_invalid_token_num: torch.Tensor,
+):
+ batch_size = Logits.shape[0]
+ grid = (batch_size,)
+ _fwd_kernel_apply_invalid_token[grid](
+ Logits=Logits,
+ invalid_token_ids=invalid_token_ids,
+ cu_invalid_token_num=cu_invalid_token_num,
+ stride_logit_b=Logits.stride(0),
+ )
+ return
diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py
similarity index 100%
rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py
rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py
diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py
similarity index 100%
rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py
rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..c75c871c72
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..14026090e6
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "800": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json
new file mode 100644
index 0000000000..939c939523
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 16
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8448": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..13ba4ba8e5
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "67584": {
+ "BLOCK_M": 64,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ee316f610b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..e027701092
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ddda23d257
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.float16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..560ca6c09d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..0713de7996
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..e950ff0954
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.float16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": true,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": true,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..7f479b8382
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 2
+ },
+ "8448": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..b3051c6584
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "8448": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..fdb3212216
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.float16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8448": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..a94e669353
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "67584": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..441421fd5d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.0/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.float16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "32768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "67584": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..864d1d3f18
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "2048": {
+ "BLOCK_SIZE": 4096,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..bcf56e01f7
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_stages": 1,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..ba1dc8a75d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=3072,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "3072": {
+ "BLOCK_SIZE": 2048,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..6f109e1c6e
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "5120": {
+ "BLOCK_SIZE": 32768,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..198a196dfb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "2048": {
+ "BLOCK_SIZE": 1024,
+ "num_stages": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..537c7a90eb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "256": {
+ "BLOCK_SIZE": 512,
+ "num_stages": 1,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..9a6dcb6fbf
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=4096,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "4096": {
+ "BLOCK_SIZE": 1024,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..df501847ec
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=5120,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "5120": {
+ "BLOCK_SIZE": 1024,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py
index c62a2572ff..60c12e4c07 100644
--- a/lightllm/common/triton_utils/autotuner.py
+++ b/lightllm/common/triton_utils/autotuner.py
@@ -106,14 +106,10 @@ def __init__(
self.configs_gen_func = configs_gen_func
self.kernel_name = kernel_name
- self.cache_dir = os.path.join(
- Path(__file__).parent,
- "autotune_kernel_configs",
- get_triton_version(),
- get_current_device_name(),
- self.kernel_name,
- )
- os.makedirs(self.cache_dir, exist_ok=True)
+ # cache_dir depends on get_current_device_name(), which requires torch.cuda.is_available().
+ # Resolve it lazily so that CPU-only processes (e.g. a Ray driver / verl rollout
+ # replica entrypoint) do not hit a TypeError at import time.
+ self._cache_dir: Optional[str] = None
self.fn = fn
self.static_key_func = static_key_func
self.run_key_func = run_key_func
@@ -209,6 +205,25 @@ def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs)
+ @property
+ def cache_dir(self) -> str:
+ if self._cache_dir is None:
+ device_name = get_current_device_name()
+ if device_name is None:
+ raise RuntimeError(
+ f"Autotuner for kernel {self.kernel_name} requires a visible CUDA/MUSA device "
+ f"to resolve its cache directory, but torch.cuda.is_available() is False."
+ )
+ self._cache_dir = os.path.join(
+ Path(__file__).parent,
+ "autotune_kernel_configs",
+ get_triton_version(),
+ device_name,
+ self.kernel_name,
+ )
+ os.makedirs(self._cache_dir, exist_ok=True)
+ return self._cache_dir
+
def _try_load_cache(self, static_key):
if static_key in self.cached_configs:
return False
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index fa2dee444f..88c4b1e8ee 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -231,6 +231,7 @@ def _moe_ffn_tp(
use_grouped_topk=self.n_group,
topk_group=self.topk_group,
num_expert_group=self.n_group,
+ microbatch_index=infer_state.microbatch_index,
)
if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
@@ -258,6 +259,7 @@ def _moe_ffn_edp(
topk_group=self.topk_group,
num_expert_group=self.n_group,
is_prefill=infer_state.is_prefill,
+ microbatch_index=infer_state.microbatch_index,
)
if self.n_shared_experts is not None:
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index 3eb09f9176..bd72035072 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -242,6 +242,9 @@ def _init_moe(self):
# == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。
if self.num_fused_shared_experts == 0:
self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True)
+ first_moe = self.network_config_["first_k_dense_replace"]
+ freq = self.network_config_.get("moe_layer_freq", 1)
+ moe_layer_index = (self.layer_num_ - first_moe) // freq
self.experts = FusedMoeWeight(
gate_proj_name="gate_proj",
down_proj_name="down_proj",
@@ -256,6 +259,7 @@ def _init_moe(self):
num_fused_shared_experts=self.num_fused_shared_experts,
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=moe_layer_index,
)
def _init_ffn(self):
diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py
index e596eed97c..79bd327068 100644
--- a/lightllm/models/deepseek2/model.py
+++ b/lightllm/models/deepseek2/model.py
@@ -6,6 +6,7 @@
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num
from lightllm.distributed.communication_op import dist_group_manager
@@ -49,6 +50,9 @@ def _init_some_value(self):
def _init_custom(self):
self._init_to_get_yarn_rotary()
dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
+ if self.args.enable_return_routed_experts:
+ num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe)
+ init_routing_capture(self, num_moe_layers)
def _verify_params(self):
return super()._verify_params()
diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
index 4c457fd993..bb9e6140bf 100644
--- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py
@@ -52,6 +52,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
+ microbatch_index=infer_state.microbatch_index,
)
hidden_states = hidden_states.view(num_tokens, hidden_dim)
return self._tpsp_reduce(input=hidden_states, infer_state=infer_state)
diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
index 7c8c30940e..7278c62fec 100644
--- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py
@@ -55,6 +55,7 @@ def _init_moe(self):
num_fused_shared_experts=0,
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=self.layer_num_,
)
def _init_weight_names(self):
diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py
index 9e9561eb24..cff748933d 100644
--- a/lightllm/models/gpt_oss/model.py
+++ b/lightllm/models/gpt_oss/model.py
@@ -2,6 +2,7 @@
from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.models.registry import ModelRegistry
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.attention import get_prefill_att_backend_class, get_decode_att_backend_class
@@ -21,6 +22,12 @@ class GptOssTpPartModel(LlamaTpPartModel):
def __init__(self, kvargs):
super().__init__(kvargs)
+ def _init_custom(self):
+ super()._init_custom()
+ if self.args.enable_return_routed_experts:
+ num_moe_layers = len(self.trans_layers_weight)
+ init_routing_capture(self, num_moe_layers)
+
def _init_att_backend(self):
self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])(
model=self
diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py
deleted file mode 100644
index b0e27ac1de..0000000000
--- a/lightllm/models/mixtral/layer_infer/_custom_ops.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import functools
-import json
-import os
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-import triton
-import triton.language as tl
-from lightllm.utils.log_utils import init_logger
-
-logger = init_logger(__name__)
-
-# Pytorch version
-# Triton version in progress
-def topk_softmax(
- topk_weights,
- topk_ids,
- token_expert_indicies,
- gating_output,
- topk=2,
-):
- scores = torch.softmax(gating_output, dim=-1)
- topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)
- return topk_weights, topk_ids
-
-
-def fused_topk(
- hidden_states: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- renormalize: bool,
- alloc_tensor_func=torch.empty,
-):
- assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
- M, _ = hidden_states.shape
-
- topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device)
- topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
- token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
- topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk)
- del token_expert_indicies # Not used. Will be used in the future.
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- return topk_weights, topk_ids
diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
index 0cf651598a..d90c631547 100644
--- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py
@@ -1,9 +1,6 @@
-import os
import torch
-import torch.nn.functional as F
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
-from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk
from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight
@@ -21,25 +18,14 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor
num_tokens, hidden_dim = hidden_states.shape
router_logits = layer_weight.moe_gate.mm(hidden_states)
- topk_weights, topk_ids = fused_topk(
- hidden_states=hidden_states,
- gating_output=router_logits,
- topk=self.num_experts_per_tok,
+ layer_weight.experts.experts(
+ hidden_states,
+ router_logits=router_logits,
+ top_k=self.num_experts_per_tok,
renormalize=self.renormalize,
- alloc_tensor_func=self.alloc_tensor,
+ use_grouped_topk=False,
+ topk_group=None,
+ num_expert_group=None,
+ microbatch_index=getattr(infer_state, "microbatch_index", 0),
)
- from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl
-
- ffn2_out = fused_experts_impl(
- hidden_states=hidden_states,
- w1=layer_weight.experts.w1[0],
- w2=layer_weight.experts.w2[0],
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- use_fp8_w8a8=False,
- w1_scale=None,
- w2_scale=None,
- alloc_tensor_func=self.alloc_tensor,
- )
- return self._tpsp_reduce(input=ffn2_out, infer_state=infer_state)
+ return hidden_states.view(num_tokens, hidden_dim)
diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
index 51c62fd4cb..d93cb5fb58 100644
--- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
@@ -57,4 +57,5 @@ def _init_moe(self):
quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=self.layer_num_,
)
diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py
index 3c2d7b4e87..35bf38de58 100644
--- a/lightllm/models/mixtral/model.py
+++ b/lightllm/models/mixtral/model.py
@@ -2,6 +2,7 @@
import numpy as np
from lightllm.models.registry import ModelRegistry
from lightllm.common.basemodel.basemodel import TpPartBaseModel
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.common.kv_cache_mem_manager import MemoryManager
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
@@ -45,6 +46,9 @@ def _verify_params(self):
def _init_custom(self):
self._init_to_get_rotary()
+ if self.args.enable_return_routed_experts:
+ num_moe_layers = len(self.trans_layers_weight)
+ init_routing_capture(self, num_moe_layers)
return
def _init_mem_manager(self):
diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py
index 237c4ad897..c94135573b 100644
--- a/lightllm/models/qwen2_vl/model.py
+++ b/lightllm/models/qwen2_vl/model.py
@@ -12,6 +12,7 @@
from .vision_process import smart_resize
from lightllm.models.qwen2.model import Qwen2TpPartModel
import os
+from typing import Union, List
# Warp of the origal tokenizer
class QWen2VLTokenizer(BaseMultiModalTokenizer):
@@ -52,9 +53,13 @@ def get_image_token_length(self, img: ImageItem):
def get_audio_token_length(self, audio: AudioItem):
raise NotImplementedError
- def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
-
- origin_ids = self.tokenizer.encode(prompt)
+ def encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams = None, **kwargs):
+ if isinstance(prompt, str):
+ origin_ids = self.tokenizer.encode(prompt)
+ elif isinstance(prompt, list):
+ origin_ids = prompt
+ else:
+ raise ValueError(f"Unsupported prompt type: {type(prompt)}")
#
->
origin_ids = [token for token in origin_ids if token != self.image_token_id]
diff --git a/lightllm/models/qwen3_5_moe/layer_infer/__init__.py b/lightllm/models/qwen3_5_moe/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
index fe4b1883bd..44425e7e10 100644
--- a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
@@ -12,7 +12,6 @@ def load_hf_weights(self, weights):
def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_size: int):
layer_prefix = f"model.layers.{layer_num}."
keys = list(weights.keys())
- num_experts = 0
for k in keys:
if not k.startswith(layer_prefix):
@@ -20,21 +19,8 @@ def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_s
if "mlp.experts.gate_up_proj" in k:
fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size]
- num_experts = fused_weight.shape[0]
-
prefix = k.rsplit(".gate_up_proj", 1)[0]
gate_weight = fused_weight[:, :moe_intermediate_size, :]
up_weight = fused_weight[:, moe_intermediate_size:, :]
-
- for expert_idx in range(num_experts):
- weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx]
- weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx]
-
- elif "mlp.experts.down_proj" in k:
- down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size]
- num_experts = down_weight.shape[0]
-
- prefix = k.rsplit(".down_proj", 1)[0]
-
- for expert_idx in range(num_experts):
- weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx]
+ weights[f"{prefix}.gate_proj"] = gate_weight
+ weights[f"{prefix}.up_proj"] = up_weight
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 54e4373652..744ddc9d4f 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -85,6 +85,7 @@ def _moe_ffn_tp(
use_grouped_topk=False,
topk_group=None,
num_expert_group=None,
+ microbatch_index=infer_state.microbatch_index,
)
return hidden_states.view(num_tokens, hidden_dim)
@@ -104,6 +105,7 @@ def _moe_ffn_edp(
topk_group=None,
num_expert_group=None,
is_prefill=infer_state.is_prefill,
+ microbatch_index=infer_state.microbatch_index,
)
ep_output = ep_output.view(token_num, hidden_dim)
diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
index e525cb2d20..5358229949 100644
--- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
@@ -52,6 +52,11 @@ def _init_moe(self):
tp_rank=0,
tp_world_size=1,
)
+ mlp_only = set(self.network_config_.get("mlp_only_layers", []))
+ step = self.network_config_.get("decoder_sparse_step", 1)
+ moe_layer_index = sum(
+ 1 for i in range(self.layer_num_) if self.n_routed_experts > 0 and i not in mlp_only and (i + 1) % step == 0
+ )
self.experts = FusedMoeWeight(
gate_proj_name="gate_proj",
down_proj_name="down_proj",
@@ -65,6 +70,7 @@ def _init_moe(self):
quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
layer_num=self.layer_num_,
network_config=self.network_config_,
+ moe_layer_index=moe_layer_index,
)
def _init_qkv(self):
diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index b71d7f4878..506f69e3ff 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -1,9 +1,14 @@
import torch
from typing import final
from lightllm.models.registry import ModelRegistry
-from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
-from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import (
+ Qwen3MOETransformerLayerInfer,
+)
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import (
+ Qwen3MOETransformerLayerWeight,
+)
from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.common.basemodel.routing_manager import init_routing_capture
from lightllm.utils.log_utils import init_logger
from lightllm.distributed.communication_op import dist_group_manager
@@ -28,3 +33,6 @@ def _init_custom(self):
# Only initialize DeepEP group for MoE models with num_experts
if "num_experts" in self.config and self.config["num_experts"] > 0:
dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
+ if self.args.enable_return_routed_experts:
+ num_moe_layers = sum(1 for w in self.trans_layers_weight if w.is_moe)
+ init_routing_capture(self, num_moe_layers)
diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py
index 9f83832a7e..b4be10d0d0 100644
--- a/lightllm/models/qwen3_moe_mtp/model.py
+++ b/lightllm/models/qwen3_moe_mtp/model.py
@@ -26,6 +26,7 @@ def _pre_init(self, kvargs: dict):
return
def _init_custom(self):
+ super()._init_custom()
self._cos_cached = self.main_model._cos_cached
self._sin_cached = self.main_model._sin_cached
return
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 70b638715e..871001d0cb 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -1,8 +1,7 @@
import argparse
-def make_argument_parser() -> argparse.ArgumentParser:
- parser = argparse.ArgumentParser()
+def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
@@ -261,6 +260,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--nccl_port", type=int, default=None, help="the nccl_port to build a distributed environment for PyTorch"
)
+ parser.add_argument(
+ "--lightllm_instance_id",
+ type=int,
+ default=0,
+ help="Instance ID (0~7) for multi-instance port isolation. Each ID maps to a dedicated port range.",
+ )
parser.add_argument(
"--use_config_server_to_init_nccl",
action="store_true",
@@ -747,6 +752,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
)
+ parser.add_argument(
+ "--enable_torch_memory_saver",
+ action="store_true",
+ help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""",
+ )
+ parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""")
parser.add_argument(
"--disk_cache_dir",
type=str,
@@ -820,4 +831,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
If the op is not implemented for the platform and the hardware support triton,
it will use triton implementation.""",
)
+ parser.add_argument(
+ "--enable_return_routed_experts",
+ action="store_true",
+ default=False,
+ help="Enable returning routed expert indices for MoE models (R3 feature).",
+ )
return parser
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 4e2c723558..13904e44bd 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -34,7 +34,7 @@
import uuid
from PIL import Image
import multiprocessing as mp
-from typing import AsyncGenerator, Union
+from typing import Any, AsyncGenerator, Union
from typing import Callable
from lightllm.server import TokenLoad
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -50,6 +50,7 @@
from lightllm.utils.error_utils import ServerBusyError
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import ReleaseMemoryReq, ResumeMemoryReq
from dataclasses import dataclass
from .api_openai import chat_completions_impl, completions_impl
@@ -61,6 +62,15 @@
ModelCard,
ModelListResponse,
)
+from .io_struct import (
+ AbortReq,
+ FlushCacheReq,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromTensorReq,
+ GeneralModelToHttpRpcRsp,
+)
from .build_prompt import build_prompt, init_tokenizer
logger = init_logger(__name__)
@@ -188,6 +198,22 @@ def get_model_name():
return {"model_name": g_objs.args.model_name}
+@app.get("/get_server_info")
+@app.post("/get_server_info")
+def get_server_info():
+ # Convert the StartArgs dataclass into a plain dict
+ from dataclasses import asdict
+
+ server_info: dict[str, Any] = asdict(g_objs.args)
+ return {**server_info}
+
+
+@app.get("/get_weight_version")
+@app.post("/get_weight_version")
+def get_weight_version():
+ return {"weight_version": g_objs.args.weight_version}
+
+
@app.get("/healthz", summary="Check server health")
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
@@ -389,6 +415,83 @@ async def metrics() -> Response:
return response
+@app.post("/abort_request")
+async def abort_request(request: AbortReq, raw_request: Request):
+ """Abort a request."""
+ try:
+ await g_objs.httpserver_manager.abort_request(request)
+ return Response(status_code=200)
+ except Exception as e:
+ return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
+
+
+async def handle_request_common(request_obj, handler):
+ try:
+ ret: GeneralModelToHttpRpcRsp = await handler(request_obj)
+ if ret.success:
+ return JSONResponse({"success": ret.success, "message": ret.msg}, status_code=200)
+ else:
+ return create_error_response(HTTPStatus.BAD_REQUEST, ret.msg)
+ except Exception as e:
+ logger.error("handle_request_common (%s) error occurred: %s", str(request_obj), str(e), exc_info=True)
+ return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
+
+
+@app.post("/init_weights_update_group")
+async def init_weights_update_group(request: InitWeightsUpdateGroupReq, raw_request: Request):
+ """Init weights update group."""
+ return await handle_request_common(request, g_objs.httpserver_manager.init_weights_update_group)
+
+
+@app.post("/destroy_weights_update_group")
+async def destroy_weights_update_group(request: DestroyWeightsUpdateGroupReq, raw_request: Request):
+ """Destroy weights update group."""
+ return await handle_request_common(request, g_objs.httpserver_manager.destroy_weights_update_group)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(request: UpdateWeightsFromDistributedReq, raw_request: Request):
+ """Update model parameter from distributed online."""
+ return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_distributed)
+
+
+@app.post("/update_weights_from_tensor")
+async def update_weights_from_tensor(request: UpdateWeightsFromTensorReq, raw_request: Request):
+ """Update model parameters from a serialized tensor payload."""
+ return await handle_request_common(request, g_objs.httpserver_manager.update_weights_from_tensor)
+
+
+@app.post("/flush_cache")
+@app.get("/flush_cache")
+async def flush_cache():
+ """Flush the radix cache."""
+ return await handle_request_common(FlushCacheReq(), g_objs.httpserver_manager.flush_cache)
+
+
+@app.post("/pause_generation")
+async def pause_generation():
+ await g_objs.httpserver_manager.pause_generation()
+ return Response(content="Generation paused successfully.", status_code=200)
+
+
+@app.post("/continue_generation")
+async def continue_generation():
+ await g_objs.httpserver_manager.continue_generation()
+ return Response(content="Generation continued successfully.", status_code=200)
+
+
+@app.get("/release_memory_occupation")
+@app.post("/release_memory_occupation")
+async def release_memory_occupation(request: ReleaseMemoryReq):
+ return await handle_request_common(request, g_objs.httpserver_manager.release_memory_occupation)
+
+
+@app.get("/resume_memory_occupation")
+@app.post("/resume_memory_occupation")
+async def resume_memory_occupation(request: ResumeMemoryReq):
+ return await handle_request_common(request, g_objs.httpserver_manager.resume_memory_occupation)
+
+
@app.websocket("/pd_register")
async def register_and_keep_alive(websocket: WebSocket):
await websocket.accept()
diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py
index 39a5808aab..28d57ccdc4 100644
--- a/lightllm/server/api_lightllm.py
+++ b/lightllm/server/api_lightllm.py
@@ -35,6 +35,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
prompt = request_dict.pop("inputs")
sample_params_dict = request_dict["parameters"]
return_details = sample_params_dict.pop("return_details", False)
+ return_routed_experts = sample_params_dict.pop(
+ "return_routed_experts", httpserver_manager.args.enable_return_routed_experts
+ )
sampling_params = SamplingParams()
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
sampling_params.verify()
@@ -53,6 +56,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
prompt_token_ids = None
is_first_metadata = True
input_usage = None
+ routed_experts_data = None
async for sub_req_id, request_output, metadata, finish_status in results_generator:
# when set "--return_all_prompt_logprobs", the first token metadata will contains
# prompt_logprobs and prompt_token_ids
@@ -78,6 +82,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
if finish_status.is_finished():
finish_reason_dict[sub_req_id] = finish_status
+ if "routed_experts" in metadata:
+ routed_experts_data = metadata["routed_experts"]
n = sampling_params.n
sub_ids = list(final_output_dict.keys())[:n]
final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids]
@@ -102,6 +108,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
ret["prompt_logprobs"] = prompt_logprobs
if input_usage is not None:
ret["input_usage"] = input_usage
+ if return_routed_experts and routed_experts_data is not None:
+ ret["routed_experts"] = routed_experts_data
return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))
@@ -112,6 +120,7 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer
prompt = request_dict.pop("inputs")
sample_params_dict = request_dict["parameters"]
_ = sample_params_dict.pop("return_details", False)
+ _ = sample_params_dict.pop("return_routed_experts", None)
sampling_params = SamplingParams()
sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
sampling_params.verify()
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index 6e04d5d47e..5306ecb698 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -1,11 +1,22 @@
import torch
-from .api_cli import make_argument_parser
+from .api_cli import add_cli_args
+from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
-if __name__ == "__main__":
- torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
- parser = make_argument_parser()
- args = parser.parse_args()
- from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start
+
+def launch_server(args: StartArgs):
+ from .api_start import pd_master_start, normal_or_p_d_start, config_server_start, visual_only_start
+
+ try:
+ # this code will not be ok for settings to fork to subprocess
+ torch.multiprocessing.set_start_method("spawn")
+ except RuntimeError as e:
+ logger.warning(f"Failed to set start method: {e}")
+ except Exception as e:
+ logger.error(f"Failed to set start method: {e}")
+ raise e
if args.run_mode == "pd_master":
pd_master_start(args)
@@ -15,3 +26,13 @@
visual_only_start(args)
else:
normal_or_p_d_start(args)
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser()
+ add_cli_args(parser)
+ args = parser.parse_args()
+
+ launch_server(StartArgs(**vars(args)))
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 6f25e42c88..fe793829e5 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -1,3 +1,4 @@
+import multiprocessing as mp
import os
import sys
import time
@@ -17,6 +18,7 @@
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.redis_utils import start_redis_service
from lightllm.utils.shm_size_check import check_recommended_shm_size
+from lightllm.server.core.objs.start_args_type import StartArgs
from lightllm.utils.config_utils import has_audio_module, has_vision_module, is_linear_att_mixed_model
logger = init_logger(__name__)
@@ -53,9 +55,31 @@ def signal_handler(sig, frame):
process_manager.terminate_all_processes()
logger.info("All processes have been terminated gracefully.")
sys.exit(0)
+ elif sig == signal.SIGHUP:
+ logger.info("Received SIGHUP (terminal closed), shutting down gracefully...")
+ if http_server_process and http_server_process.poll() is None:
+ http_server_process.send_signal(signal.SIGTERM)
+
+ start_time = time.time()
+ while (time.time() - start_time) < 60:
+ if not is_process_active(http_server_process.pid):
+ logger.info("httpserver exit")
+ break
+ time.sleep(1)
+
+ if time.time() - start_time < 60:
+ logger.info("HTTP server has exited gracefully")
+ else:
+ logger.warning("HTTP server did not exit in time, killing it...")
+ kill_recursive(http_server_process)
+
+ process_manager.terminate_all_processes()
+ logger.info("All processes have been terminated gracefully due to terminal closure.")
+ sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGHUP, signal_handler)
logger.info(f"start process pid {os.getpid()}")
if http_server_process:
@@ -63,12 +87,13 @@ def signal_handler(sig, frame):
return
-def normal_or_p_d_start(args):
- from lightllm.server.core.objs.start_args_type import StartArgs
+def _set_envs_and_config(args: StartArgs):
+ mp.set_start_method("spawn", force=True)
- args: StartArgs = args
- set_unique_server_name(args)
+def _launch_subprocesses(args: StartArgs):
+
+ _set_envs_and_config(args)
if args.enable_mps:
from lightllm.utils.device_utils import enable_mps
@@ -133,12 +158,6 @@ def normal_or_p_d_start(args):
check_recommended_shm_size(args)
assert args.zmq_mode in ["tcp://", "ipc:///tmp/"]
- # 确保单机上多实列不冲突
- if args.zmq_mode == "ipc:///tmp/":
- zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_"
- args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功
- args.zmq_mode = zmq_mode
- logger.info(f"zmq mode head: {args.zmq_mode}")
logger.info(f"use tgi api: {args.use_tgi_api}")
@@ -184,12 +203,16 @@ def normal_or_p_d_start(args):
# mtp params check
if args.mtp_mode is not None:
- assert args.mtp_draft_model_dir is not None
+ if args.mtp_draft_model_dir is None:
+ args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step
assert args.mtp_step > 0
else:
assert args.mtp_draft_model_dir is None
assert args.mtp_step == 0
+ # automatically set visual_dp based on visual_tp and tp
+ if args.visual_tp < args.tp and args.tp % args.visual_tp == 0:
+ args.visual_dp = args.tp // args.visual_tp
if args.afs_image_embed_dir is not None:
os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True)
os.chmod(args.afs_image_embed_dir, 0o777)
@@ -324,6 +347,7 @@ def normal_or_p_d_start(args):
node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
num=10 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp + args.audio_dp,
+ instance_id=args.lightllm_instance_id,
used_ports=already_uesd_ports,
)
logger.info(f"alloced ports: {can_use_ports}")
@@ -358,6 +382,16 @@ def normal_or_p_d_start(args):
args.nccl_port = nccl_port
if args.pd_decode_rpyc_port is None:
args.pd_decode_rpyc_port = pd_decode_rpyc_port
+
+ set_unique_server_name(args)
+
+ # 确保单机上多实例不冲突
+ if args.zmq_mode == "ipc:///tmp/":
+ zmq_mode = f"{args.zmq_mode}_{get_unique_server_name()}_"
+ args.zmq_mode = None # args 的参数不能直接设置,只能先设置None,再设置才能成功
+ args.zmq_mode = zmq_mode
+ logger.info(f"zmq mode head: {args.zmq_mode}")
+
args.router_port = router_port
args.detokenization_port = detokenization_port
args.http_server_port = http_server_port
@@ -462,6 +496,13 @@ def normal_or_p_d_start(args):
],
)
+ return process_manager
+
+
+def normal_or_p_d_start(args: StartArgs):
+
+ process_manager = _launch_subprocesses(args)
+
# 启动 Hypercorn
command = [
"hypercorn",
@@ -497,7 +538,7 @@ def normal_or_p_d_start(args):
return
-def pd_master_start(args):
+def pd_master_start(args: StartArgs):
set_unique_server_name(args)
if args.run_mode != "pd_master":
return
@@ -514,10 +555,7 @@ def pd_master_start(args):
logger.info(f"all start args:{args}")
can_use_ports = alloc_can_use_network_port(
- num=1,
- used_ports=[
- args.port,
- ],
+ num=1, used_nccl_ports=[args.nccl_port, args.port], instance_id=args.lightllm_instance_id
)
metric_port = can_use_ports[0]
diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py
index efe24c53e3..e78d56fc4f 100644
--- a/lightllm/server/audioserver/manager.py
+++ b/lightllm/server/audioserver/manager.py
@@ -11,11 +11,11 @@
from typing import List
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.io_struct import BaseReq, GenerateReqIndex
from lightllm.server.core.objs import ShmReqManager, StartArgs
from lightllm.server.multimodal_params import AudioItem
from .model_infer import start_model_process, AudioModelRpcClient
-from lightllm.utils.log_utils import init_logger
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
@@ -77,7 +77,7 @@ async def wait_to_model_ready(self):
await asyncio.gather(*init_model_ret)
return
- def get_need_infer_audios(self, group_req_indexes: GroupReqIndexes) -> List[AudioItem]:
+ def get_need_infer_audios(self, group_req_indexes: GenerateReqIndex) -> List[AudioItem]:
shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
is_aborted = shm_req.is_aborted
disable_prompt_cache = shm_req.sample_params.disable_prompt_cache
@@ -102,7 +102,7 @@ def get_need_infer_audios(self, group_req_indexes: GroupReqIndexes) -> List[Audi
return audios_need_infer
- async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes):
+ async def handle_group_indexes(self, group_req_indexes: GenerateReqIndex):
audios_need_infer = self.get_need_infer_audios(group_req_indexes)
if len(audios_need_infer) == 0:
@@ -153,13 +153,16 @@ async def infer_audios(self, dp_index: int, audios, events):
async def loop_for_netio_req(self):
try:
while True:
- recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
- if isinstance(recv_req, GroupReqIndexes):
+ recv_req: BaseReq = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
+ if isinstance(recv_req, GenerateReqIndex):
logger.info(
f"audio recv req id {recv_req.group_req_id} "
f"audio count {len(recv_req.multimodal_params.audios)}"
)
asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req))
+ elif isinstance(recv_req, BaseReq):
+ # RL 等控制类 BaseReq 透传给下一模块,最终由 router 处理
+ self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL)
else:
assert False, f"Error Req Inf {recv_req}"
except Exception as e:
diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py
index c9b806c47d..10386b70e6 100644
--- a/lightllm/server/core/objs/io_objs/__init__.py
+++ b/lightllm/server/core/objs/io_objs/__init__.py
@@ -1 +1 @@
-from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd
+from .group_req import AbortedReqCmd, StopStrMatchedReqCmd
diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py
index dfcbdd2562..d644c0c316 100644
--- a/lightllm/server/core/objs/io_objs/group_req.py
+++ b/lightllm/server/core/objs/io_objs/group_req.py
@@ -1,33 +1,10 @@
from dataclasses import dataclass
from lightllm.server.multimodal_params import MultimodalParams
+from lightllm.server.core.objs.sampling_params import SamplingParams
from typing import List
from ..req import Req
-@dataclass
-class GroupReqIndexes:
- group_req_id: int
- multimodal_params: MultimodalParams
- shm_req_indexes: List[int]
- time_mark: float
-
-
-@dataclass
-class GroupReqObjs:
- group_req_id: int
- multimodal_params: MultimodalParams
- shm_req_objs: List[Req]
- time_mark: float
-
- def to_group_req_index(self):
- return GroupReqIndexes(
- group_req_id=self.group_req_id,
- multimodal_params=self.multimodal_params,
- shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs],
- time_mark=self.time_mark,
- )
-
-
@dataclass
class AbortedReqCmd:
req_id: int
diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py
index ea99dae5f6..8019c9a1a1 100644
--- a/lightllm/server/core/objs/out_token_circlequeue.py
+++ b/lightllm/server/core/objs/out_token_circlequeue.py
@@ -4,6 +4,9 @@
LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1280))
LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8))
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
class QueueItem(ctypes.Structure):
@@ -24,9 +27,19 @@ def __init__(self):
def set(self, token_str: str, src_index: int, special: bool, count_output_tokens: int):
str_bytes = token_str.encode("utf-8")
- assert (
- len(str_bytes) <= LIGHTLLM_TOKEN_MAX_BYTES
- ), f"Token string {len(str_bytes)} exceeds maximum length of {LIGHTLLM_TOKEN_MAX_BYTES} bytes."
+ max_data_len = max(LIGHTLLM_TOKEN_MAX_BYTES - 1, 0)
+ if len(str_bytes) > max_data_len:
+ logger.error(
+ "Token string exceeds max bytes: bytes=%d limit=%d src_index=%d count_output_tokens=%d preview=%s",
+ len(str_bytes),
+ max_data_len,
+ src_index,
+ count_output_tokens,
+ token_str,
+ )
+ str_bytes = str_bytes[:max_data_len]
+ # Ensure truncation never leaves an incomplete UTF-8 sequence.
+ str_bytes = str_bytes.decode("utf-8", errors="ignore").encode("utf-8")
ctypes.memmove(self.data, str_bytes, len(str_bytes))
self.data_len = len(str_bytes)
self.src_index = src_index
diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py
index bf7051abdc..4489ccd708 100644
--- a/lightllm/server/core/objs/py_sampling_params.py
+++ b/lightllm/server/core/objs/py_sampling_params.py
@@ -54,6 +54,8 @@ def __init__(
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
allowed_token_ids: Optional[List[int]] = None,
+ # if provided, the invalid token ids will be ignored during generation
+ invalid_token_ids: Optional[List[int]] = None,
# p d mode used params
group_request_id: Optional[int] = None,
# move kv to deocde node, only used in pd mode
@@ -89,6 +91,7 @@ def __init__(
self.guided_grammar = guided_grammar
self.guided_json = guided_json
self.allowed_token_ids = allowed_token_ids
+ self.invalid_token_ids = invalid_token_ids
self.group_request_id = group_request_id
self.move_kv_to_decode_node = move_kv_to_decode_node
self.suggested_dp_index = suggested_dp_index
@@ -111,13 +114,18 @@ def __init__(
def load_generation_cfg(cls, weight_dir):
try:
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
- cls._do_sample = generation_cfg.get("do_sample", False)
- cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
- cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
- cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
- cls._temperature = generation_cfg.get("temperature", 1.0)
- cls._top_p = generation_cfg.get("top_p", 1.0)
- cls._top_k = generation_cfg.get("top_k", -1)
+
+ def _cfg(key, default):
+ v = generation_cfg.get(key)
+ return v if v is not None else default
+
+ cls._do_sample = _cfg("do_sample", False)
+ cls._presence_penalty = _cfg("presence_penalty", 0.0)
+ cls._frequency_penalty = _cfg("frequency_penalty", 0.0)
+ cls._repetition_penalty = _cfg("repetition_penalty", 1.0)
+ cls._temperature = _cfg("temperature", 1.0)
+ cls._top_p = _cfg("top_p", 1.0)
+ cls._top_k = _cfg("top_k", -1)
cls._stop_sequences = generation_cfg.get("stop", None)
except:
pass
@@ -269,6 +277,7 @@ def to_dict(self):
ret["guided_grammar"] = self.guided_grammar
ret["guided_json"] = self.guided_json
ret["allowed_token_ids"] = self.allowed_token_ids
+ ret["invalid_token_ids"] = self.invalid_token_ids
ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
ret["seed"] = self.seed
return ret
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 7f2b697091..d954870393 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -1,6 +1,7 @@
import os
import math
import ctypes
+import base64
import numpy as np
import time
from .sampling_params import SamplingParams
@@ -14,6 +15,7 @@
from lightllm.utils.kv_cache_utils import compute_token_list_hash
from typing import List, Any, Union
from lightllm.utils.log_utils import init_logger
+from lightllm.utils.shm_utils import create_or_link_shm
logger = init_logger(__name__)
@@ -25,19 +27,20 @@ class FinishStatus(ctypes.Structure):
NO_FINISH = 0
FINISHED_STOP = 1
FINISHED_LENGTH = 2
+ FINISHED_ABORTED = 3
def __init__(self, init_state=NO_FINISH):
self.status = init_state
def set_status(self, new_status):
- assert 0 <= new_status <= 2
+ assert 0 <= new_status <= 3
self.status = new_status
def get_status(self):
return self.status
def is_finished(self):
- return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH
+ return self.FINISHED_STOP <= self.status <= self.FINISHED_ABORTED
def is_stopped(self):
return self.status == self.FINISHED_STOP
@@ -50,6 +53,8 @@ def get_finish_reason(self):
return "stop"
elif self.status == self.FINISHED_LENGTH:
return "length"
+ elif self.status == self.FINISHED_ABORTED:
+ return "abort"
return None
@@ -125,6 +130,8 @@ class Req(ctypes.Structure):
("token_hash_page_len_list", TokenPageLenList),
# 用于保存查找匹配到的可以被复用的cpu cache 页面信息。
("cpu_cache_match_page_indexes", CpuCachePageList),
+ # Number of tokens in routing data SHM, written by model worker, read by HTTP server.
+ ("shm_routing_num_tokens", ctypes.c_int),
]
def get_str(self):
@@ -182,6 +189,7 @@ def init(
self._mtp_step = get_env_start_args().mtp_step
self.stop_str_matched = False
self.stop_str_matched_token_index = -1
+ self.shm_routing_num_tokens = 0
self.post_init()
@@ -277,6 +285,69 @@ def link_logprobs_shm_array(self):
self.shm_logprobs.link_shm()
return
+ def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int, np_dtype=np.int8):
+ """Create routing SHM at actual size (on-demand, not pre-allocated).
+
+ Uses smart mode: links if same-sized SHM exists, otherwise creates new.
+ """
+ service_uni_name = get_unique_server_name()
+ name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
+ shape = (num_tokens, num_moe_layers, topk)
+ self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype)
+ self.shm_routing_data.create_shm()
+ self.shm_routing_num_tokens = num_tokens
+ return
+
+ def link_routing_data_shm_array(self, num_moe_layers: int, topk: int, np_dtype=np.int8):
+ """Link to routing SHM from the reader side (HTTP server)."""
+ if num_moe_layers == 0:
+ return
+ num_tokens = self.shm_routing_num_tokens
+ if num_tokens <= 0:
+ return
+ service_uni_name = get_unique_server_name()
+ name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}"
+ shape = (num_tokens, num_moe_layers, topk)
+ self.shm_routing_data = ShmArray(name, shape, dtype=np_dtype)
+ self.shm_routing_data.link_shm()
+ return
+
+ def get_routing_data(self):
+ if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None:
+ return None
+ return self.shm_routing_data.arr
+
+ def close_routing_data_shm_array(self):
+ """Close and unlink routing SHM (on-demand, no longer pooled)."""
+ if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None:
+ self.shm_routing_data.close_shm()
+ self.shm_routing_data = None
+ self.shm_routing_num_tokens = 0
+ return
+
+ def get_routing_metadata(self, num_moe_layers: int, topk: int, dtype_id: int = 1):
+ if num_moe_layers == 0 or topk == 0:
+ return None
+ if self.shm_routing_num_tokens <= 0:
+ return None
+ try:
+ from lightllm.common.basemodel.routing_manager import routing_dtype_id_to_np
+
+ np_dtype = routing_dtype_id_to_np(dtype_id)
+ if not hasattr(self, "shm_routing_data") or self.shm_routing_data is None:
+ self.link_routing_data_shm_array(num_moe_layers, topk, np_dtype=np_dtype)
+ routing_data = self.get_routing_data()
+ if routing_data is None:
+ return None
+ return {
+ "shape": list(routing_data.shape),
+ "dtype": str(routing_data.dtype),
+ "data": base64.b64encode(routing_data.tobytes()).decode("ascii"),
+ }
+ except Exception as e:
+ logger.warning(f"Failed to read routing data for req {self.request_id}: {e}")
+ return None
+
def get_prompt_ids(self):
return self.shm_prompt_ids.arr[: self.input_len].tolist()
@@ -297,9 +368,8 @@ def can_release(self):
ref_count_ok = self.ref_count == 1
can_released_mark = self.can_released_mark
- if self.is_aborted and can_released_mark and ref_count_ok:
- return True
-
+ # if self.is_aborted and can_released_mark and ref_count_ok:
+ # return True
ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched
if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 707f8f31b5..28fc57e034 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -17,6 +17,7 @@
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
+INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
class StopSequence(ctypes.Structure):
@@ -205,6 +206,25 @@ def to_list(self):
return list(self.ids[: self.size])
+class InvalidTokenIds(ctypes.Structure):
+ _pack_ = 4
+ _fields_ = [
+ ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH),
+ ("size", ctypes.c_int),
+ ]
+
+ def initialize(self, ids: List[int]):
+ self.size = len(ids)
+ assert (
+ self.size <= INVALID_TOKEN_IDS_MAX_LENGTH
+ ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}."
+ self.ids[: self.size] = ids[:]
+ return
+
+ def to_list(self):
+ return list(self.ids[: self.size])
+
+
class ExponentialDecayLengthPenalty(ctypes.Structure):
_pack_ = 4
_fields_ = [
@@ -293,6 +313,8 @@ class SamplingParams(ctypes.Structure):
("ignore_eos", ctypes.c_bool),
# the max number of image patches to be used in the internvl model, for the test
("image_max_patch_num", ctypes.c_int),
+ ("min_pixels", ctypes.c_int),
+ ("max_pixels", ctypes.c_int),
("max_new_tokens", ctypes.c_int),
("min_new_tokens", ctypes.c_int),
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
@@ -304,6 +326,8 @@ class SamplingParams(ctypes.Structure):
# processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
("allowed_token_ids", AllowedTokenIds),
+ # if provided, the invalid token ids will be ignored during generation
+ ("invalid_token_ids", InvalidTokenIds),
("stop_sequences", StopSequenceGroups),
("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
("group_request_id", ctypes.c_int64), # p d mode used params
@@ -334,15 +358,31 @@ class SamplingParams(ctypes.Structure):
def init(self, tokenizer, **kwargs):
super().__init__()
+ # 移除kwargs中为null的参数,避免覆盖默认值
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
self.best_of = kwargs.get("best_of", 1)
self.n = kwargs.get("n", self.best_of)
- self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
- self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
- self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
- self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
- self.temperature = kwargs.get("temperature", SamplingParams._temperature)
- self.top_p = kwargs.get("top_p", SamplingParams._top_p)
- self.top_k = kwargs.get("top_k", SamplingParams._top_k)
+ do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
+ self.do_sample = False if do_sample is None else do_sample
+
+ presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
+ self.presence_penalty = 0.0 if presence_penalty is None else presence_penalty
+
+ frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
+ self.frequency_penalty = 0.0 if frequency_penalty is None else frequency_penalty
+
+ repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
+ self.repetition_penalty = 1.0 if repetition_penalty is None else repetition_penalty
+
+ temperature = kwargs.get("temperature", SamplingParams._temperature)
+ self.temperature = 1.0 if temperature is None else temperature
+
+ top_p = kwargs.get("top_p", SamplingParams._top_p)
+ self.top_p = 1.0 if top_p is None else top_p
+
+ top_k = kwargs.get("top_k", SamplingParams._top_k)
+ self.top_k = -1 if top_k is None else top_k
self.ignore_eos = kwargs.get("ignore_eos", False)
self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
self.max_new_tokens = kwargs.get("max_new_tokens", 65535)
@@ -394,6 +434,11 @@ def init(self, tokenizer, **kwargs):
self.allowed_token_ids = AllowedTokenIds()
self.allowed_token_ids.initialize(allowed_token_ids)
+ # Initialize invalid_token_ids
+ invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys())
+ self.invalid_token_ids = InvalidTokenIds()
+ self.invalid_token_ids.initialize(list(invalid_token_ids))
+
if self.do_sample is False:
self.temperature = 1.0
self.top_p = 1.0
@@ -410,15 +455,34 @@ def init(self, tokenizer, **kwargs):
def load_generation_cfg(cls, weight_dir):
try:
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
+ # Some checkpoints store null sampling fields in generation_config.json.
+ # Keep robust numeric defaults instead of propagating None into ctypes fields.
cls._do_sample = generation_cfg.get("do_sample", False)
+ if cls._do_sample is None:
+ cls._do_sample = False
+
cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
+ if cls._presence_penalty is None:
+ cls._presence_penalty = 0.0
+
cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
+ if cls._frequency_penalty is None:
+ cls._frequency_penalty = 0.0
+
cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
if cls._repetition_penalty is None:
cls._repetition_penalty = 1.0
cls._temperature = generation_cfg.get("temperature", 1.0)
+ if cls._temperature is None:
+ cls._temperature = 1.0
+
cls._top_p = generation_cfg.get("top_p", 1.0)
+ if cls._top_p is None:
+ cls._top_p = 1.0
+
cls._top_k = generation_cfg.get("top_k", -1)
+ if cls._top_k is None:
+ cls._top_k = -1
except:
pass
@@ -485,6 +549,8 @@ def to_dict(self):
"image_max_patch_num": self.image_max_patch_num,
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
+ "min_pixels": self.min_pixels,
+ "max_pixels": self.max_pixels,
"exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(),
"stop_sequences": self.stop_sequences.to_list(),
"best_of": self.best_of,
@@ -493,6 +559,7 @@ def to_dict(self):
"guided_grammar": self.guided_grammar.to_str(),
"guided_json": self.guided_json.to_str(),
"allowed_token_ids": self.allowed_token_ids.to_list(),
+ "invalid_token_ids": self.invalid_token_ids.to_list(),
"group_request_id": self.group_request_id,
"move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),
"skip_special_tokens": self.skip_special_tokens,
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index bcc4b3798a..d98b7e612d 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
-# 只是为了更好的编程提示
+# 服务启动参数
@dataclass
@@ -9,16 +9,27 @@ class StartArgs:
run_mode: str = field(
default="normal",
metadata={
- "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode", "visual_only"]
+ "choices": [
+ "normal",
+ "prefill",
+ "decode",
+ "nixl_prefill",
+ "nixl_decode",
+ "pd_master",
+ "config_server",
+ "visual_only",
+ ]
},
)
+ performance_mode: Optional[str] = field(default=None, metadata={"choices": ["personal"]})
host: str = field(default="127.0.0.1")
port: int = field(default=8000)
+ httpserver_workers: int = field(default=1)
zmq_mode: str = field(
default="ipc:///tmp/",
metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"},
)
- pd_master_ip: str = field(default="127.0.0.1")
+ pd_master_ip: str = field(default="0.0.0.0")
pd_master_port: int = field(default=1212)
config_server_host: str = field(default=None)
config_server_port: int = field(default=None)
@@ -26,18 +37,34 @@ class StartArgs:
afs_image_embed_dir: str = field(default=None)
afs_embed_capacity: int = field(default=250000)
pd_decode_rpyc_port: int = field(default=None)
- select_p_d_node_strategy: str = field(default=None)
+ select_p_d_node_strategy: str = field(
+ default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]}
+ )
model_name: str = field(default="default_model_name")
+ model_owner: Optional[str] = field(default=None)
model_dir: Optional[str] = field(default=None)
- tokenizer_mode: str = field(default="slow")
+ tokenizer_mode: str = field(default="fast")
load_way: str = field(default="HF")
max_total_token_num: Optional[int] = field(default=None)
mem_fraction: float = field(default=0.9)
batch_max_tokens: Optional[int] = field(default=None)
- eos_id: List[int] = field(default_factory=list)
+ eos_id: Optional[List[int]] = field(default=None)
tool_call_parser: Optional[str] = field(
default=None,
- metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]},
+ metadata={
+ "choices": [
+ "qwen25",
+ "llama3",
+ "mistral",
+ "deepseekv3",
+ "qwen",
+ "deepseekv31",
+ "deepseekv32",
+ "glm47",
+ "kimi_k2",
+ "qwen3_coder",
+ ]
+ },
)
reasoning_parser: Optional[str] = field(
default=None,
@@ -60,20 +87,21 @@ class StartArgs:
},
)
chat_template: Optional[str] = field(default=None)
- running_max_req_size: int = field(default=512)
+ running_max_req_size: int = field(default=256)
tp: int = field(default=1)
dp: int = field(default=1)
nnodes: int = field(default=1)
node_rank: int = field(default=0)
- max_req_total_len: int = field(default=2048 + 1024)
+ max_req_total_len: int = field(default=16384)
nccl_host: str = field(default="127.0.0.1")
nccl_port: int = field(default=None)
+ lightllm_instance_id: int = field(default=0)
use_config_server_to_init_nccl: bool = field(default=False)
trust_remote_code: bool = field(default=False)
detail_log: bool = field(default=False)
disable_log_stats: bool = field(default=False)
log_stats_interval: int = field(default=10)
- router_token_ratio: float = field(default=0.0)
+ router_token_ratio: Optional[float] = field(default=None)
router_max_wait_tokens: int = field(default=1)
disable_aggressive_schedule: bool = field(default=False)
disable_dynamic_prompt_cache: bool = field(default=False)
@@ -81,7 +109,7 @@ class StartArgs:
disable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
token_healing_mode: bool = field(default=False)
- output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]})
+ output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]})
first_token_constraint_mode: bool = field(default=False)
enable_multimodal: bool = field(default=False)
disable_vision: Optional[bool] = field(default=None)
@@ -103,12 +131,12 @@ class StartArgs:
health_monitor: bool = field(default=False)
metric_gateway: Optional[str] = field(default=None)
job_name: str = field(default="lightllm")
- grouping_key: List[str] = field(default_factory=list)
+ grouping_key: List[str] = field(default_factory=lambda: [])
push_interval: int = field(default=10)
visual_node_id: int = field(default=None)
visual_infer_batch_size: int = field(default=None)
visual_send_batch_size: int = field(default=1)
- visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
+ visual_gpu_ids: Optional[List[int]] = field(default=None)
visual_tp: int = field(default=1)
visual_dp: int = field(default=1)
visual_nccl_ports: List[int] = field(default=None)
@@ -126,18 +154,18 @@ class StartArgs:
graph_split_batch_size: int = field(default=32)
graph_grow_step_size: int = field(default=16)
graph_max_len_in_batch: int = field(default=0)
- quant_type: Optional[str] = field(default=None)
+ quant_type: Optional[str] = field(default="none")
quant_cfg: Optional[str] = field(default=None)
- vit_quant_type: Optional[str] = field(default=None)
+ vit_quant_type: Optional[str] = field(default="none")
vit_quant_cfg: Optional[str] = field(default=None)
llm_prefill_att_backend: List[str] = field(
- default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
+ default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
)
llm_decode_att_backend: List[str] = field(
- default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
+ default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
)
vit_att_backend: List[str] = field(
- default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
+ default_factory=lambda: ["auto"], metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
)
llm_kv_type: str = field(
default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]}
@@ -158,8 +186,6 @@ class StartArgs:
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
- "qwen3next_vanilla",
- "qwen3next_eagle",
None,
]
},
@@ -172,7 +198,7 @@ class StartArgs:
pd_node_id: int = field(default=-1)
enable_cpu_cache: bool = field(default=False)
cpu_cache_storage_size: float = field(default=2)
- cpu_cache_token_page_size: int = field(default=64)
+ cpu_cache_token_page_size: int = field(default=256)
enable_disk_cache: bool = field(default=False)
disk_cache_storage_size: float = field(default=10)
disk_cache_dir: Optional[str] = field(default=None)
@@ -187,6 +213,27 @@ class StartArgs:
metric_port: int = field(default=None)
multinode_httpmanager_port: int = field(default=12345)
multi_level_kv_cache_port: int = field(default=None)
+ # multi_modal
+ enable_multimodal_audio: bool = field(default=False)
+
+ disable_shm_warning: bool = field(default=False)
+ dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]})
+ enable_custom_allgather: bool = field(default=False)
+ enable_fused_shared_experts: bool = field(default=False)
+ enable_mps: bool = field(default=False)
+ multinode_router_gloo_port: int = field(default=20001)
+ schedule_time_interval: float = field(default=0.03)
+ use_dynamic_prompt_cache: bool = field(default=False)
+ disable_custom_allreduce: bool = field(default=False)
+ enable_torch_memory_saver: bool = field(default=False)
+ enable_weight_cpu_backup: bool = field(default=False)
+ hardware_platform: str = field(default="cuda", metadata={"choices": ["cuda", "musa"]})
+ enable_torch_fallback: bool = field(default=False)
+ enable_triton_fallback: bool = field(default=False)
+
+ enable_return_routed_experts: bool = field(default=False)
+
+    weight_version: str = field(default="default")
# hybrid attention model (Qwen3Next)
linear_att_hash_page_size: int = field(default=512)
diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py
index 9aa3a8effc..c77379986c 100644
--- a/lightllm/server/detokenization/decode_req.py
+++ b/lightllm/server/detokenization/decode_req.py
@@ -62,11 +62,7 @@ def stop_sequences_str_match(self) -> bool:
return False
def need_detoken(self):
- if (
- (not self.req.is_aborted)
- and (not self.req.stop_str_matched)
- and len(self.output_ids) < self.req.candetoken_out_len
- ):
+ if (not self.req.stop_str_matched) and len(self.output_ids) < self.req.candetoken_out_len:
return True
return False
@@ -83,8 +79,6 @@ def get_decode_tokens(self):
return prefix_tokens, read_tokens
def can_set_release_mark(self):
- if self.req.is_aborted:
- return True
if self.req.stop_str_matched:
return True
if (
diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py
index 389171ba8a..17a47dfde6 100644
--- a/lightllm/server/detokenization/manager.py
+++ b/lightllm/server/detokenization/manager.py
@@ -6,7 +6,6 @@
import zmq
import inspect
from lightllm.server.core.objs import ShmReqManager, StartArgs
-from lightllm.server.core.objs.io_objs import GroupReqIndexes
from lightllm.utils.graceful_utils import graceful_registry
from typing import Union, Dict, List
from .decode import decode_token
@@ -17,6 +16,14 @@
import time
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import (
+ BaseReq,
+ GenerateResp,
+ FlushCacheResp,
+ ReleaseMemoryResp,
+ ResumeMemoryResp,
+ GeneralModelToHttpRpcRsp,
+)
logger = init_logger(__name__)
@@ -31,9 +38,9 @@ def __init__(
self.zmq_recv_socket = context.socket(zmq.PULL)
self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}")
- self.pub_to_httpserver = context.socket(zmq.PUB)
- self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}")
- logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}")
+ self.send_to_httpserver = context.socket(zmq.PUSH)
+ self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}")
+ logger.info(f"send_to_httpserver sendhwm {self.send_to_httpserver.getsockopt(zmq.SNDHWM)}")
self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)
self.all_special_ids = set(self.tokenizer.all_special_ids)
self.req_id_to_out: Dict[int, DecodeReq] = {}
@@ -46,7 +53,7 @@ def _init_get_token_id_to_token_str(self):
self.token_id_to_token = {token_id: token for token, token_id in self.tokenizer.get_vocab().items()}
return
- def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
+ def _add_new_group_req_index(self, recv_obj: BaseReq):
for req_index in recv_obj.shm_req_indexes:
req = self.shm_req_manager.get_req_obj_by_index(req_index)
req.link_prompt_ids_shm_array()
@@ -74,8 +81,10 @@ def handle_loop(self):
try:
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
for _ in range(recv_max_count):
- recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
- assert isinstance(recv_obj, GroupReqIndexes)
+ recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+ if isinstance(recv_obj, GeneralModelToHttpRpcRsp):
+ self.send_to_httpserver.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL)
+ continue
self._add_new_group_req_index(recv_obj=recv_obj)
# 当队列中存在较多的请求时,将一次接受的数量上调
@@ -146,7 +155,7 @@ def gen_token_out(self):
# 通知 httpserver 进程
if exist_decode:
- self.pub_to_httpserver.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)
+ self.send_to_httpserver.send_pyobj(GenerateResp(), protocol=pickle.HIGHEST_PROTOCOL)
self.remove_finished_reqs()
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 892e202e2d..30e2db0e53 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -10,6 +10,7 @@
import hashlib
import datetime
import pickle
+import inspect
from frozendict import frozendict
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -25,16 +26,41 @@
from lightllm.server.core.objs import Req, FinishStatus, StartArgs
from lightllm.server.core.objs import SamplingParams
from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
-from lightllm.server.core.objs.io_objs import GroupReqObjs
from lightllm.server.core.objs.shm_req_manager import ShmReqManager
from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.common.basemodel.routing_manager import get_routing_config_shm
from lightllm.utils.log_utils import init_logger
from lightllm.server.metrics.manager import MetricClient
+from lightllm.server.io_struct import (
+ AbortReq,
+ BaseReq,
+ FlushCacheReq,
+ FlushCacheResp,
+ GenerateReq,
+ GenerateResp,
+ GenerateReqMeta,
+ GenerateReqIndex,
+ ReleaseMemoryReq,
+ ReleaseMemoryResp,
+ ResumeMemoryReq,
+ ResumeMemoryResp,
+ InitWeightsUpdateGroupReq,
+ InitWeightsUpdateGroupRsp,
+ DestroyWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupRsp,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromDistributedRsp,
+ UpdateWeightsFromTensorReq,
+ UpdateWeightsFromTensorRsp,
+ GeneralHttpToModelRpcReq,
+ GeneralModelToHttpRpcRsp,
+)
from lightllm.utils.statics_utils import MovingAverage
from lightllm.utils.config_utils import get_vocab_size
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
from rpyc.utils.classic import obtain
logger = init_logger(__name__)
@@ -74,7 +100,7 @@ def __init__(
self.multinode_req_manager = context.socket(zmq.PULL)
self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}")
logger.info(
- f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}"
+ f"HttpServerManager listening for master node requests on *:{args.multinode_httpmanager_port}"
)
self.enable_multimodal = args.enable_multimodal
@@ -98,9 +124,8 @@ def __init__(
self.shm_req_manager = ShmReqManager()
# recv from detokenization
- self.zmq_recv_socket = context.socket(zmq.SUB)
+ self.zmq_recv_socket = context.socket(zmq.PULL)
self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}")
- self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"")
self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)
@@ -122,6 +147,18 @@ def __init__(
# If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
self.latest_success_infer_time_mark.set_value(int(time.time()))
+
+ # Cache routing config for MoE expert routing data extraction
+ self._routing_shm = get_routing_config_shm() if args.enable_return_routed_experts else None
+
+ self.is_pause = False
+ self.is_pause_cond = asyncio.Condition()
+
+ # 交互式请求 event
+ self.flush_cache_event: Optional[asyncio.Event] = None
+ self.release_memory_event: Optional[asyncio.Event] = None
+ self.resume_memory_event: Optional[asyncio.Event] = None
+ self.async_events_per_func: Dict[str, asyncio.Event] = {}
return
def _log_stage_timing(self, group_request_id: int, start_time: float, stage: str, **kwargs):
@@ -249,18 +286,32 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
async def loop_for_request(self):
assert self.args.node_rank > 0
while True:
- (
- prompt,
- sampling_params,
- multimodal_params,
- ) = await self.multinode_req_manager.recv_pyobj()
- results_generator = self.generate(prompt, sampling_params, multimodal_params, None)
+ req_obj = await self.multinode_req_manager.recv_pyobj()
+ if req_obj is None:
+ continue
+ if isinstance(req_obj, GenerateReqMeta):
+ self.process_generate_request(req_obj)
+ elif isinstance(req_obj, AbortReq):
+ self.process_abort_request(req_obj)
+ else:
+ assert False, f"Unknown request type: {type(req_obj)}"
+ return
- async def generate_wrapper(results_generator):
- async for _, _, _, _ in results_generator:
- pass
+ def process_generate_request(self, req_meta: GenerateReqMeta):
+ prompt = req_meta.prompt
+ sampling_params = req_meta.sampling_params
+ multimodal_params = req_meta.multimodal_params
+ results_generator = self.generate(prompt, sampling_params, multimodal_params, None)
- asyncio.create_task(generate_wrapper(results_generator))
+ async def generate_wrapper(results_generator):
+ async for _, _, _, _ in results_generator:
+ pass
+
+ asyncio.create_task(generate_wrapper(results_generator))
+ return
+
+ def process_abort_request(self, request: AbortReq):
+ asyncio.create_task(self.abort_request(request))
return
def alloc_req_id(self, sampling_params, is_health_req: bool = False):
@@ -313,10 +364,6 @@ async def generate(
)
try:
- original_multimodal_params = None
- if self.is_multinode_tp_master:
- original_multimodal_params = copy.deepcopy(multimodal_params)
-
if self.pd_mode.is_P_or_NORMAL():
await multimodal_params.verify_and_preload(request)
self._log_stage_timing(
@@ -325,8 +372,20 @@ async def generate(
"verify_and_preload_done",
)
+ # Debug logging for multimodal requests
+ if multimodal_params and multimodal_params.images:
+ logger.debug(
+ f"[MULTIMODAL_DEBUG] req_id={group_request_id}, "
+ f"num_images={len(multimodal_params.images)}, "
+ f"max_new_tokens={sampling_params.max_new_tokens}"
+ )
+
# 记录请求到达的相关信息
await self._log_req_header(request_headers, group_request_id)
+
+ async with self.is_pause_cond:
+ await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
# encode
prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
self._log_stage_timing(
@@ -400,18 +459,17 @@ async def generate(
"shm_req_init_done",
)
- logger.debug(
- f"alloc shm_req for req_id {group_request_id}, "
- f"shm_req num: {sampling_params.n} details (req_id, index_in_shm_mem): "
- f"{[(req_obj.request_id, req_obj.index_in_shm_mem) for req_obj in req_objs]}"
+ req_status = ReqStatus(
+ group_request_id=group_request_id,
+ prompt=prompt,
+ sampling_params=sampling_params,
+ multimodal_params=multimodal_params,
+ req_objs=req_objs,
+ start_time=start_time,
)
-
- req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time)
self.req_id_to_out_inf[group_request_id] = req_status
- await self.transfer_to_next_module_or_node(
- prompt, sampling_params, original_multimodal_params, req_status.group_req_objs
- )
+ await self.transfer_to_next_module_or_node(req_status.group_req_objs)
self._log_stage_timing(
group_request_id,
start_time,
@@ -527,7 +585,21 @@ async def _encode(
# 这里的校验对多模态不是很充分, to do
if all(isinstance(e, int) for e in prompt):
- if not self.enable_multimodal and not self.pd_mode.is_D():
+ if self.enable_multimodal:
+ assert (
+ len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
+ ), "too many multimodal items!"
+ if multimodal_params.audios:
+ assert self.args.enable_multimodal_audio, "audio multimodal not enabled"
+ await self._alloc_multimodal_resources(multimodal_params, sampling_params)
+ prompt_ids = self.tokenizer.encode(
+ prompt,
+ multimodal_params,
+ add_special_tokens=sampling_params.add_special_tokens,
+ already_tokenized=True,
+ )
+ return prompt_ids
+ elif not self.enable_multimodal and not self.pd_mode.is_D():
if all(e < self.vocab_size for e in prompt):
return prompt
else:
@@ -581,45 +653,50 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
async def transfer_to_next_module_or_node(
self,
- prompt: str,
- sampling_params: SamplingParams,
- original_multimodal_params: MultimodalParams,
- group_req_objs: Optional[GroupReqObjs] = None,
+ req_obj: Optional["BaseReq"] = None,
):
# 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点.
+ req_to_next_node = req_obj.get_req_to_next_node()
+ self.transfer_to_next_node(req_to_next_node)
+ req_to_next_module = req_obj.get_req_to_next_module()
+ await self.transfer_to_next_module(req_to_next_module)
+ return
+
+ def transfer_to_next_node(
+ self,
+ req_to_next_node: Optional["BaseReq"] = None,
+ ):
if self.is_multinode_tp_master:
for sender in self.multinode_req_manager:
sender.send_pyobj(
- (prompt, sampling_params, original_multimodal_params),
+ req_to_next_node,
protocol=pickle.HIGHEST_PROTOCOL,
)
-
- await self.transfer_to_next_module(group_req_objs)
return
async def transfer_to_next_module(
self,
- group_req_objs: Optional[GroupReqObjs] = None,
+ req_to_next_module: Optional["GenerateReqIndex"] = None,
):
if self.pd_mode.is_P_or_NORMAL():
if not self.args.disable_vision:
- self.send_to_visual.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL)
+ self.send_to_visual.send_pyobj(req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL)
return
if not self.args.disable_audio:
- self.send_to_audio.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL)
+ self.send_to_audio.send_pyobj(req_to_next_module, protocol=pickle.HIGHEST_PROTOCOL)
return
if self.args.enable_cpu_cache:
self.send_to_multi_level_kv_cache.send_pyobj(
- group_req_objs.to_group_req_index(),
+ req_to_next_module,
protocol=pickle.HIGHEST_PROTOCOL,
)
return
self.send_to_router.send_pyobj(
- group_req_objs.to_group_req_index(),
+ req_to_next_module,
protocol=pickle.HIGHEST_PROTOCOL,
)
return
@@ -627,7 +704,7 @@ async def transfer_to_next_module(
if self.pd_mode.is_D():
# 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了
self.send_to_router.send_pyobj(
- group_req_objs.to_group_req_index(),
+ req_to_next_module,
protocol=pickle.HIGHEST_PROTOCOL,
)
return
@@ -761,12 +838,24 @@ async def abort(self, group_req_id: int) -> bool:
logger.warning(f"aborted group_request_id {group_req_id} not exist")
return False
- group_req_objs: GroupReqObjs = req_status.group_req_objs
+ group_req_objs: GenerateReq = req_status.group_req_objs
for req in group_req_objs.shm_req_objs:
req.is_aborted = True
logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
return True
+ async def abort_request(self, request: AbortReq):
+ request_id = request.request_id
+ abort_all = request.abort_all
+ if self.is_multinode_tp_master:
+ self.transfer_to_next_node(req_to_next_node=request)
+ if request_id is not None and not abort_all:
+ await self.abort(request_id)
+ if abort_all:
+ for group_req_id in list(self.req_id_to_out_inf.keys()):
+ await self.abort(group_req_id)
+        return
+
async def recycle_resource_loop(self):
pre_time_mark = time.time()
@@ -789,6 +878,11 @@ async def recycle_resource_loop(self):
self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None)
_is_aborted = False
for req in req_status.group_req_objs.shm_req_objs:
+ if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None:
+ try:
+ req.close_routing_data_shm_array()
+ except Exception as e:
+ logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}")
_is_aborted = _is_aborted or req.is_aborted
logger.debug(f"httpserver release req_id {req.request_id}, index {req.index_in_shm_mem}")
await self.shm_req_manager.async_put_back_req_obj(req)
@@ -829,65 +923,16 @@ async def handle_loop(self):
while True:
try:
- await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05)
+ recv_obj = await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05)
except asyncio.TimeoutError:
- pass
+ recv_obj = None
try:
- for group_req_id_ in list(self.req_id_to_out_inf.keys()):
- req_status = self.req_id_to_out_inf.get(group_req_id_, None)
- if req_status is None:
- continue
-
- token_list = []
- for req in req_status.group_req_objs.shm_req_objs:
- req_id = req.request_id
- read_token_count = 1
- if req.out_tokens_queue.is_full():
- read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
-
- for _ in range(read_token_count):
- if not req.out_tokens_queue.is_empty():
-
- text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
- req.cumlogprob += float(req.shm_logprobs.arr[src_index])
- metadata = {
- "id": int(req.shm_prompt_ids.arr[src_index]),
- "logprob": float(req.shm_logprobs.arr[src_index]),
- "cumlogprob": float(req.cumlogprob) / count_output_tokens,
- "special": special,
- "count_output_tokens": count_output_tokens,
- "prompt_cache_len": req.prompt_cache_len,
- "cpu_prompt_cache_len": req.cpu_prompt_cache_len,
- "disk_prompt_cache_len": req.disk_prompt_cache_len,
- "mtp_accepted_token_num": req.mtp_accepted_token_num,
- }
- if self.args.return_all_prompt_logprobs:
- metadata.update(req.get_all_prompt_metadata())
- if self.args.use_reward_model:
- metadata["score"] = float(req.reward_score)
-
- req.out_tokens_queue.pop_no_ret()
-
- finished_token_index = (
- req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
- )
-
- if finished_token_index != src_index:
- token_list.append((req_id, text, metadata, FinishStatus()))
- else:
- if req.stop_str_matched:
- finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
- else:
- finish_status = FinishStatus(req.finish_status.status)
-
- token_list.append((req_id, text, metadata, finish_status))
- else:
- break
+ if recv_obj is None or isinstance(recv_obj, GenerateResp):
+ await self._handle_recv_generate_request(recv_obj)
+ elif isinstance(recv_obj, GeneralModelToHttpRpcRsp):
+ await self._handle_recv_general_model_to_http_request(recv_obj)
- async with req_status.lock:
- req_status.out_token_info_list.extend(token_list)
- req_status.event.set()
except BaseException as e:
logger.exception(str(e))
raise e
@@ -895,13 +940,189 @@ async def handle_loop(self):
self.recycle_event.set()
return
+    async def _handle_recv_generate_request(self, recv_obj: Optional[GenerateResp]):
+ for group_req_id_ in list(self.req_id_to_out_inf.keys()):
+ req_status = self.req_id_to_out_inf.get(group_req_id_, None)
+ if req_status is None:
+ continue
+
+ token_list = []
+ for req in req_status.group_req_objs.shm_req_objs:
+ req_id = req.request_id
+ read_token_count = 1
+ if req.out_tokens_queue.is_full():
+ read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
+
+ for _ in range(read_token_count):
+ if not req.out_tokens_queue.is_empty():
+
+ text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
+ req.cumlogprob += float(req.shm_logprobs.arr[src_index])
+ metadata = {
+ "id": int(req.shm_prompt_ids.arr[src_index]),
+ "logprob": float(req.shm_logprobs.arr[src_index]),
+ "cumlogprob": float(req.cumlogprob) / count_output_tokens,
+ "special": special,
+ "count_output_tokens": count_output_tokens,
+ "prompt_cache_len": req.prompt_cache_len,
+ "cpu_prompt_cache_len": req.cpu_prompt_cache_len,
+ "mtp_accepted_token_num": req.mtp_accepted_token_num,
+ }
+ if self.args.return_all_prompt_logprobs:
+ metadata.update(req.get_all_prompt_metadata())
+ if self.args.use_reward_model:
+ metadata["score"] = float(req.reward_score)
+
+ req.out_tokens_queue.pop_no_ret()
+
+ finished_token_index = (
+ req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
+ )
+
+ if finished_token_index != src_index:
+ token_list.append((req_id, text, metadata, FinishStatus()))
+ else:
+ if req.stop_str_matched:
+ finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
+ else:
+ finish_status = FinishStatus(req.finish_status.status)
+
+ if self._routing_shm is not None:
+ _num_moe = int(self._routing_shm.arr[0])
+ _topk = int(self._routing_shm.arr[1])
+ _dtype_id = int(self._routing_shm.arr[2])
+ if _num_moe > 0:
+ routing_meta = req.get_routing_metadata(_num_moe, _topk, dtype_id=_dtype_id)
+ if routing_meta is not None:
+ metadata["routed_experts"] = routing_meta
+
+ token_list.append((req_id, text, metadata, finish_status))
+ else:
+ break
+
+ async with req_status.lock:
+ req_status.out_token_info_list.extend(token_list)
+ req_status.event.set()
+
+ async def _handle_recv_general_model_to_http_request(self, recv_obj: GeneralModelToHttpRpcRsp):
+ assert recv_obj.func_name is not None
+ event = await self.get_event_for_func(recv_obj.func_name)
+ event.result = recv_obj
+ event.set()
+ return
+
+ async def pause_generation(self):
+ # 因为请求是从master node转发到slave node的
+ # 所以只要master暂停了,slave自然暂停。
+ if self.is_pause:
+ return
+ async with self.is_pause_cond:
+ self.is_pause = True
+ while True:
+ await self.abort_request(AbortReq(request_id=None, abort_all=True))
+ running_req_num = len(list(self.req_id_to_out_inf.keys()))
+ if running_req_num == 0:
+ break
+ await asyncio.sleep(1.0)
+
+ async def continue_generation(self):
+ async with self.is_pause_cond:
+ self.is_pause = False
+ self.is_pause_cond.notify_all()
+
+ async def get_event_for_func(self, func_name: str) -> asyncio.Event:
+ if func_name not in self.async_events_per_func:
+ self.async_events_per_func[func_name] = asyncio.Event()
+ return self.async_events_per_func[func_name]
+
+ async def http_to_model_special_request(
+ self, request: GeneralHttpToModelRpcReq, timeout: int = 300
+ ) -> GeneralModelToHttpRpcRsp:
+ event = await self.get_event_for_func(request.func_name)
+ await self.transfer_to_next_module(request)
+ try:
+ await asyncio.wait_for(event.wait(), timeout=timeout)
+ ret = event.result
+
+ except asyncio.TimeoutError:
+ ret = GeneralModelToHttpRpcRsp(success=False, msg="wait for response timeout", func_name=request.func_name)
+ except Exception as e:
+ ret = GeneralModelToHttpRpcRsp(
+ success=False, msg="wait for response error: %s" % str(e), func_name=request.func_name
+ )
+ return ret
+
+ async def flush_cache(self, request: FlushCacheReq):
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="flush_cache", func_args=request)
+ )
+
+ async def release_memory_occupation(self, request: ReleaseMemoryReq):
+ assert len(self.req_id_to_out_inf) == 0, "there are still requests running, cannot release memory occupation"
+ # 暂停接受请求,除非resume
+ await self.pause_generation()
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="release_memory_occupation", func_args=request.tags)
+ )
+
+ async def resume_memory_occupation(self, request: ResumeMemoryReq):
+ ret = await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="resume_memory_occupation", func_args=request.tags)
+ )
+ if ret.success:
+ await self.continue_generation()
+ return ret
+
+ async def init_weights_update_group(self, request: InitWeightsUpdateGroupReq):
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="init_weights_update_group", func_args=request)
+ )
+
+ async def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq):
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="destroy_weights_update_group", func_args=request)
+ )
+
+ async def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq):
+
+ if request.abort_all_requests:
+ await self.abort_request(AbortReq(abort_all=True))
+
+ if request.flush_cache:
+ await self.flush_cache(FlushCacheReq())
+
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="update_weights_from_distributed", func_args=request)
+ )
+
+ async def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq) -> Tuple[bool, str]:
+ if request.abort_all_requests:
+ await self.abort_request(AbortReq(abort_all=True))
+
+ if request.flush_cache:
+ await self.flush_cache(FlushCacheReq())
+
+ return await self.http_to_model_special_request(
+ GeneralHttpToModelRpcReq(func_name="update_weights_from_tensor", func_args=request)
+ )
+
class ReqStatus:
- def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None:
+ def __init__(
+ self,
+ group_request_id: int,
+ prompt: str,
+ sampling_params: SamplingParams,
+ multimodal_params: MultimodalParams,
+ req_objs: List[Req],
+ start_time,
+ ) -> None:
self.lock = asyncio.Lock()
self.event = asyncio.Event()
- self.group_req_objs = GroupReqObjs(
+ self.group_req_objs = GenerateReq(
group_req_id=group_request_id,
+ prompt=prompt,
+ sampling_params=sampling_params,
multimodal_params=multimodal_params,
shm_req_objs=req_objs,
time_mark=start_time,
diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py
new file mode 100644
index 0000000000..e04e8871ce
--- /dev/null
+++ b/lightllm/server/io_struct.py
@@ -0,0 +1,195 @@
+from abc import ABC
+from dataclasses import dataclass
+from lightllm.server.core.objs.req import Req
+from lightllm.server.core.objs.sampling_params import SamplingParams
+from lightllm.server.multimodal_params import MultimodalParams
+from typing import List, Optional, Any, Union
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
+
+
+@dataclass
+class BaseReq(ABC):
+ def get_req_to_next_node(self):
+ return self
+
+ def get_req_to_next_module(self):
+ return self
+
+
+@dataclass
+class BaseRsp(ABC):
+ success: bool
+ msg: Optional[str]
+
+
+# for next node
+@dataclass
+class GenerateReqMeta(BaseReq):
+ prompt: str
+ sampling_params: SamplingParams
+ multimodal_params: MultimodalParams
+
+
+# for next module
+@dataclass
+class GenerateReqIndex(BaseReq):
+ group_req_id: int
+ multimodal_params: MultimodalParams
+ shm_req_indexes: List[int]
+ time_mark: float
+
+
+@dataclass
+class GenerateReq(BaseReq):
+ group_req_id: int
+ prompt: str
+ sampling_params: SamplingParams
+ multimodal_params: MultimodalParams
+ shm_req_objs: List[Req]
+ time_mark: float
+
+ def get_req_to_next_module(self):
+ # 已经完成跨节点转发,可以释放图片原始资源
+ self.multimodal_params.free()
+ return GenerateReqIndex(
+ group_req_id=self.group_req_id,
+ multimodal_params=self.multimodal_params,
+ shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs],
+ time_mark=self.time_mark,
+ )
+
+ def get_req_to_next_node(self):
+ return GenerateReqMeta(
+ prompt=self.prompt,
+ sampling_params=self.sampling_params,
+ multimodal_params=self.multimodal_params,
+ )
+
+
+@dataclass
+class GenerateResp(BaseReq):
+ pass
+
+
+@dataclass
+class FlushCacheReq(BaseReq):
+ pass
+
+
+@dataclass
+class FlushCacheResp(BaseReq):
+ success: bool
+
+
+@dataclass
+class AbortReq(BaseReq):
+ # 外部调用传入,等同内部的 group_req_id
+    request_id: Optional[int] = None
+ abort_all: bool = False
+
+
+@dataclass
+class ReleaseMemoryReq(BaseReq):
+ tags: Optional[List[MemoryTag]] = None
+
+
+@dataclass
+class ReleaseMemoryResp(BaseReq):
+ success: bool
+
+
+@dataclass
+class ResumeMemoryReq(BaseReq):
+ tags: Optional[List[MemoryTag]] = None
+
+
+@dataclass
+class ResumeMemoryResp(BaseReq):
+ success: bool
+
+
+@dataclass
+class GeneralHttpToModelRpcReq(BaseReq):
+ func_name: str
+ func_args: Optional[Any] = None
+
+
+@dataclass
+class GeneralModelToHttpRpcRsp(BaseRsp):
+ func_name: str
+ func_rsp: Optional[Any] = None
+
+
+@dataclass
+class InitWeightsUpdateGroupReq(BaseReq):
+ # The master address
+ master_address: str
+ # The master port
+ master_port: int
+ # The rank offset
+ rank_offset: int
+ # The world size
+ world_size: int
+ # The group name
+ group_name: str = "weight_update_group"
+ # The backend
+ backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsUpdateGroupRsp(BaseRsp):
+ pass
+
+
+@dataclass
+class DestroyWeightsUpdateGroupReq(BaseReq):
+ group_name: str = "weight_update_group"
+
+
+@dataclass
+class DestroyWeightsUpdateGroupRsp(BaseRsp):
+ pass
+
+
+@dataclass
+class UpdateWeightsFromDistributedReq(BaseReq):
+ names: List[str]
+ dtypes: List[str]
+ shapes: List[List[int]]
+ # The group name
+ group_name: str = "weight_update_group"
+ # Whether to flush the cache after updating weights
+ flush_cache: bool = True
+ # Whether to abort all requests before updating weights
+ abort_all_requests: bool = False
+ # Optional: Update weight version along with weights
+ weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromDistributedRsp(BaseRsp):
+ pass
+
+
+@dataclass
+class UpdateWeightsFromTensorReq(BaseReq):
+ """Update model weights from tensor input.
+
+ - Tensors are serialized for transmission
+ - Data is structured in JSON for easy transmission over HTTP
+ """
+
+ serialized_named_tensors: List[Union[str, bytes]]
+ # Optional format specification for loading
+ load_format: Optional[str] = None
+ # Whether to flush the cache after updating weights
+ flush_cache: bool = True
+ # Whether to abort all requests before updating weights
+ abort_all_requests: bool = False
+ # Optional: Update weight version along with weights
+ weight_version: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightsFromTensorRsp(BaseRsp):
+ pass
diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py
index 0a7dec0005..b771ac1a01 100644
--- a/lightllm/server/multi_level_kv_cache/manager.py
+++ b/lightllm/server/multi_level_kv_cache/manager.py
@@ -12,7 +12,7 @@
from queue import Queue
from typing import List
from lightllm.server.core.objs import ShmReqManager, Req, StartArgs
-from lightllm.server.core.objs.io_objs import GroupReqIndexes
+from lightllm.server.io_struct import GenerateReqIndex, BaseReq
from lightllm.utils.graceful_utils import graceful_registry
from .cpu_cache_client import CpuKvCacheClient
from lightllm.utils.log_utils import init_logger
@@ -135,7 +135,7 @@ def _disk_cache_match(self, token_hash_list: List[int], all_pages: List[int]) ->
self.cpu_cache_client.lock.release()
return all_pages, len(new_page_indexes)
- def _handle_group_req_multi_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float):
+ def _handle_group_req_multi_cache_match(self, group_req_indexes: GenerateReqIndex, start_time: float):
"""
match cpu cache and disk cache pages
"""
@@ -218,8 +218,12 @@ def recv_loop(self):
try:
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
for _ in range(recv_max_count):
- recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
- assert isinstance(recv_obj, GroupReqIndexes)
+ recv_obj: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+ if not isinstance(recv_obj, GenerateReqIndex):
+ # RL 等控制类 BaseReq 透传给 router 处理,避免在此被静默丢弃
+ if isinstance(recv_obj, BaseReq):
+ self.send_to_router.send_pyobj(recv_obj, protocol=pickle.HIGHEST_PROTOCOL)
+ continue
recv_objs.append(recv_obj)
start_time = recv_obj.time_mark
diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py
index 6210628751..b05f34bd99 100644
--- a/lightllm/server/multimodal_params.py
+++ b/lightllm/server/multimodal_params.py
@@ -70,9 +70,11 @@ def read(self):
assert self._preload_data is not None
ans = self._preload_data
self._preload_data = None
- self._data = None
return ans
+ def free(self):
+ self._data = None
+
def to_dict(self):
ret = {}
ret["uuid"] = self.uuid
@@ -155,9 +157,11 @@ def read(self):
assert self._preload_data is not None
ans = self._preload_data
self._preload_data = None
- self._data = None
return ans
+ def free(self):
+ self._data = None
+
def to_dict(self):
ret = {}
ret["uuid"] = self.uuid
@@ -211,3 +215,10 @@ def to_origin_dict(self):
ret["images"] = [i.to_origin_dict() for i in self.images]
ret["audios"] = [a.to_origin_dict() for a in self.audios]
return ret
+
+ def free(self):
+ for image in self.images:
+ image.free()
+ for audio in self.audios:
+ audio.free()
+ return
diff --git a/lightllm/server/req_id_generator.py b/lightllm/server/req_id_generator.py
index f7c099c292..bc81d835b6 100644
--- a/lightllm/server/req_id_generator.py
+++ b/lightllm/server/req_id_generator.py
@@ -30,7 +30,8 @@ def __init__(self):
self.current_id.arr[0] = 0
self.current_id.arr[1] = 0
self.lock = AtomicShmLock(f"{get_unique_server_name()}_req_id_gen_lock")
- self._wait_all_workers_ready()
+ if self.args.httpserver_workers > 1:
+ self._wait_all_workers_ready()
logger.info("ReqIDGenerator init finished")
def _wait_all_workers_ready(self):
diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
index 73c6dba54d..082f6bec08 100644
--- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py
@@ -467,6 +467,10 @@ def clear_tree_nodes(self):
self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
return
+ def flush_cache(self):
+ self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
+ return
+
def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional[LinearAttPagedTreeNode]:
assert not node.is_big_page_node()
iter_node = node
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 88b099459b..955e808600 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -425,6 +425,10 @@ def clear_tree_nodes(self):
self.refed_tokens_num.arr[0] = 0
return
+ def flush_cache(self):
+ self.free_radix_cache_to_get_enough_token(need_token_num=self.total_token_num)
+ return
+
def dec_node_ref_counter(self, node: TreeNode):
if node is None:
return
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index f5e0b8df9a..83972a3d34 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -5,6 +5,7 @@
import pickle
import inspect
import setproctitle
+import rpyc
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
import zmq
@@ -17,7 +18,6 @@
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
from lightllm.server.core.objs.io_objs import (
- GroupReqIndexes,
AbortedReqCmd,
StopStrMatchedReqCmd,
)
@@ -29,6 +29,18 @@
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
from lightllm.common.basemodel.infer_lock import g_router_lock
+from lightllm.server.io_struct import (
+ BaseReq,
+ GenerateReqIndex,
+ FlushCacheReq,
+ FlushCacheResp,
+ ReleaseMemoryReq,
+ ReleaseMemoryResp,
+ ResumeMemoryReq,
+ ResumeMemoryResp,
+ GeneralHttpToModelRpcReq,
+ GeneralModelToHttpRpcRsp,
+)
from lightllm.common.kv_cache_mem_manager import ReadOnlyStaticsMemoryManager
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
@@ -356,8 +368,13 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]:
ans = []
if self.running_batch is None:
return ans
- for req in self.running_batch.reqs:
- if req.is_aborted and req._router_aborted is False:
+ aborted_req_mask = torch.tensor(
+ [req.is_aborted for req in self.running_batch.reqs], dtype=torch.bool, device="cpu"
+ )
+ if self.is_multinode_tp:
+ dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+ for req, is_aborted in zip(self.running_batch.reqs, aborted_req_mask.numpy()):
+ if is_aborted and req._router_aborted is False:
req._router_aborted = True
ans.append(req)
return ans
@@ -406,7 +423,7 @@ def get_used_tokens(self, dp_index):
else:
return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
- def _add_req(self, group_req_indexes: GroupReqIndexes):
+    def _add_req(self, group_req_indexes: GenerateReqIndex):
req_group = []
for req_index in group_req_indexes.shm_req_indexes:
req = self.shm_req_manager.get_req_obj_by_index(req_index)
@@ -454,9 +471,22 @@ def _multinode_tp_generate_new_batch(self):
dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
req_id_select_mark = [1 for _ in range(len(req_ids))]
req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            # TODO: 这里可以合成一个 allreduce,req_id_select_mark + aborted_req_mask
dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+ aborted_req_mask = torch.tensor(
+ [req.is_aborted for req in new_batch.reqs], dtype=torch.bool, device="cpu"
+ )
+ dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
back_req_list = []
- for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+ for req_id, select, is_aborted in zip(
+ req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy()
+ ):
+ # 释放多节点abort 请求,如果select == 0, is_aborted 一定为False
+ if is_aborted and select == 1:
+ req = new_batch.pop_req(req_id)
+ self.req_queue.free_aborted_req(req)
+ self.shm_req_manager.put_back_req_obj(req)
+ continue
if select == 0:
req = new_batch.pop_req(req_id)
back_req_list.append(req)
@@ -472,23 +502,28 @@ def _multinode_tp_generate_new_batch(self):
else:
req_ids = [None for _ in range(req_num)]
dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
- all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+ # all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+ id_to_req_obj = {req.request_id: req for req in self.req_queue.waiting_req_list}
req_id_select_mark = []
+ aborted_req_mask = []
for req_id in req_ids:
- req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+ req_id_select_mark.append(1 if req_id in id_to_req_obj else 0)
+ aborted_req_mask.append(id_to_req_obj[req_id].is_aborted if req_id in id_to_req_obj else False)
req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
- select_req_ids = []
- for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
- if select == 1:
- select_req_ids.append(req_id)
-
+ aborted_req_mask = torch.tensor(aborted_req_mask, dtype=torch.bool, device="cpu")
+ dist.all_reduce(aborted_req_mask, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
select_reqs = []
- for req_id in select_req_ids:
- for req in self.req_queue.waiting_req_list:
- if req.request_id == req_id:
- select_reqs.append(req)
-
+ for req_id, select, is_aborted in zip(
+ req_ids, req_id_select_mark.numpy(), aborted_req_mask.numpy()
+ ):
+ if select == 1:
+ req = id_to_req_obj[req_id]
+ if is_aborted:
+ self.req_queue.free_aborted_req(req)
+ self.shm_req_manager.put_back_req_obj(req)
+ continue
+ select_reqs.append(req)
for req in select_reqs:
self.req_queue.waiting_req_list.remove(req)
if select_reqs:
@@ -509,13 +544,17 @@ async def _recv_new_reqs_and_schedule(self):
self.recv_max_count = 64
try:
+ # 多机tp需要广播给其他node的请求
+ special_reqs = []
# 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
for _ in range(self.recv_max_count):
- recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
- if isinstance(recv_req, GroupReqIndexes):
+ recv_req: BaseReq = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+ if isinstance(recv_req, GenerateReqIndex):
self._add_req(recv_req)
+ elif isinstance(recv_req, GeneralHttpToModelRpcReq):
+ special_reqs.append(recv_req)
else:
- assert False, f"Error Req Inf {recv_req}"
+ raise ValueError(f"Unknown request type: {type(recv_req)}")
# 当队列中存在较多的请求时,将一次接受的数量上调
self.recv_max_count = min(int(self.recv_max_count * 1.3), 256)
@@ -524,6 +563,8 @@ async def _recv_new_reqs_and_schedule(self):
# 当队列已经开始清空的时候,将一次接受的数量下调
self.recv_max_count = 64
+ await self._process_special_reqs(special_reqs)
+
if self.is_multinode_tp:
self._multinode_tp_generate_new_batch()
else:
@@ -531,6 +572,50 @@ async def _recv_new_reqs_and_schedule(self):
self._generate_new_batch()
return
+ async def _process_special_reqs(self, special_reqs: List[BaseReq]):
+ if self.is_multinode_tp:
+ special_reqs = self.broadcast_reqs_to_other_nodes(special_reqs)
+ for req in special_reqs:
+ assert isinstance(req, GeneralHttpToModelRpcReq), "special request must be GeneralHttpToModelRpcReq"
+ await self.forward_to_model(req)
+
+ def broadcast_reqs_to_other_nodes(self, reqs: List[BaseReq]):
+ req_num = len(reqs)
+ if self.node_rank == 0:
+ req_nums = [len(reqs)]
+ dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+ req_num = req_nums[0]
+ if req_num > 0:
+ dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+ else:
+ req_nums = [None]
+ dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+ req_num = req_nums[0]
+ if req_num > 0:
+ reqs = [None for _ in range(req_num)]
+ dist.broadcast_object_list(reqs, src=0, group=self.mulitnode_group)
+ return reqs
+
+ async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> None:
+ forward_to_model_tasks = []
+ for model_rpc_client in self.model_rpc_clients:
+ forward_to_model_tasks.append(model_rpc_client.forward_to_model(req))
+ all_ret = await asyncio.gather(*forward_to_model_tasks)
+        success = all(ret.success for ret in all_ret)
+        ret = all_ret[0]
+        ret.success = success
+        if self.is_multinode_tp:
+            output_list = [None for _ in range(self.nnodes)] if self.node_rank == 0 else None
+            dist.gather_object(ret, output_list, dst=0, group=self.mulitnode_group)
+            for res in output_list or []:
+                res: GeneralModelToHttpRpcRsp
+                if not res.success:
+                    ret = res
+                    break
+
+ if self.node_rank == 0:
+ self.send_to_detokenization.send_pyobj(ret, protocol=pickle.HIGHEST_PROTOCOL)
+
def clean_up(self):
return
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 3a27c082de..b5a68299ca 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -24,6 +24,7 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo
from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.common.basemodel import routing_manager as _routing_mgr
logger = init_logger(__name__)
@@ -122,6 +123,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
return req_objs
+ def _extract_routing_data(self, req: "InferReq"):
+ if req.shm_req.shm_routing_num_tokens > 0:
+ return
+ mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
+ mgr = _routing_mgr.g_routing_capture_manager
+ routing_data = mgr.extract_routing_data(mem_indexes)
+ req.shm_req.create_routing_data_shm_array(mgr.num_moe_layers, req.cur_kv_len, mgr.topk, np_dtype=mgr.np_dtype)
+ req.shm_req.shm_routing_data.arr[:] = routing_data
+ req.shm_req.shm_routing_data.detach_shm()
+
def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
if self.radix_cache is None:
free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
@@ -444,6 +455,9 @@ def __init__(
if len(self.allowed_token_ids) == 0:
self.allowed_token_ids = None
+ # if provided, invalid_token_ids are masked to -inf during sampling (see generic_post_process.sample)
+ self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list()
+
# p d mode use params
if self.shm_param.move_kv_to_decode_node.exists:
self.move_kv_to_decode_node = self.shm_param.move_kv_to_decode_node.to_dict()
@@ -456,6 +470,11 @@ def __init__(
logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids")
self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size]
+ if len(self.invalid_token_ids) > 0:
+ if not all(e < vocab_size for e in self.invalid_token_ids):
+ logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids")
+ self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size]
+
# nixl decode node information
if self.shm_param.nixl_params.data_len > 0:
self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get())
@@ -929,6 +948,8 @@ def handle(
shm_req.shm_cur_output_len = self.output_len
if finish_status.is_finished():
+ if _routing_mgr.g_routing_capture_manager is not None:
+ g_infer_context._extract_routing_data(req_obj)
shm_req.finish_token_index = shm_req.input_len + self.output_len - 1
shm_req.finish_status = req_obj.finish_status
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 048869f860..1357e13462 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -4,7 +4,7 @@
import time
import threading
import torch.distributed as dist
-from typing import List, Tuple, Callable, Optional
+from typing import List, Tuple, Callable, Optional, Union
from transformers.configuration_utils import PretrainedConfig
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.log_utils import init_logger
@@ -19,7 +19,7 @@
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput
from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify
-from lightllm.utils.dist_utils import init_distributed_env
+from lightllm.utils.dist_utils import init_distributed_env, init_custom_process_group
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.core.objs import ShmReqManager, StartArgs
from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd
@@ -34,8 +34,11 @@
enable_radix_tree_timer_merge,
get_radix_tree_merge_update_delta,
)
+from lightllm.utils.serializer import LocalSerializedTensor, MultiprocessingSerializer
+from lightllm.utils.patch_torch import monkey_patch_torch_reductions
+from lightllm.utils.tensor_bucket import FlattenedTensorBucket, FlattenedTensorMetadata
+from lightllm.distributed import dist_group_manager
from lightllm.distributed.communication_op import (
- dist_group_manager,
all_gather_into_tensor,
all_reduce,
broadcast,
@@ -49,7 +52,16 @@
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
+from lightllm.common.basemodel import routing_manager as _routing_mgr
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.server.io_struct import (
+ FlushCacheReq,
+ InitWeightsUpdateGroupReq,
+ DestroyWeightsUpdateGroupReq,
+ UpdateWeightsFromDistributedReq,
+ UpdateWeightsFromTensorReq,
+)
class ModeBackend:
@@ -122,6 +134,8 @@ def init_model(self, kvargs):
)
dist_group_manager.create_groups(group_size=group_size) # set the default group
+ self._model_update_group = {}
+
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
# 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在
@@ -315,14 +329,15 @@ def init_mtp_draft_model(self, main_kvargs: dict):
os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1"
- if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]:
+ if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att", "qwen3next_vanilla"]:
num_mtp_modules = self.args.mtp_step
- elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]:
+ elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att", "qwen3next_eagle"]:
num_mtp_modules = 1
else:
assert False, f"error mtp mode {self.args.mtp_mode}"
for i in range(num_mtp_modules):
+ # Get MTP model config first to calculate mem_layer_start
mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i])
model_type = mtp_model_cfg.get("model_type", "")
mtp_model_kvargs = {
@@ -337,7 +352,7 @@ def init_mtp_draft_model(self, main_kvargs: dict):
"data_type": main_kvargs.get("data_type", "float16"),
"graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16),
"graph_max_len_in_batch": main_kvargs.get("graph_max_len_in_batch", 8196),
- "disable_cudagraph": main_kvargs.get("disable_cudagraph", False),
+ "disable_cudagraph": True, # Disable CUDA graphs for MTP draft models
"mem_fraction": main_kvargs["mem_fraction"],
"batch_max_tokens": main_kvargs.get("batch_max_tokens", None),
"quant_type": main_kvargs.get("quant_type", None),
@@ -345,6 +360,7 @@ def init_mtp_draft_model(self, main_kvargs: dict):
"run_mode": "normal",
"main_model": self.model,
"mtp_previous_draft_models": self.draft_models.copy(),
+ "mtp_index": i,
}
# Select MTP model class based on model type
@@ -367,6 +383,191 @@ def init_mtp_draft_model(self, main_kvargs: dict):
self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
return
+ def flush_cache(self, request: FlushCacheReq):
+ if self.radix_cache is not None:
+ self.radix_cache.flush_cache()
+ return True, "Succeeded to flush cache."
+
+ def release_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.model.req_manager.free_all()
+ self.model.mem_manager.free_all()
+ self.model.release_memory_occupation(tags)
+ self.flush_cache(request=None)
+ return True, "Succeeded to release memory occupation."
+ except Exception as e:
+ self.logger.error(f"release memory occupation failed: {str(e)}")
+ return False, f"release memory occupation failed: {str(e)}"
+
+ def resume_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.model.resume_memory_occupation(tags)
+ return True, "Succeeded to resume memory occupation."
+ except Exception as e:
+ self.logger.error(f"resume memory occupation failed: {str(e)}")
+ return False, f"resume memory occupation failed: {str(e)}"
+
+ def init_weights_update_group(self, request: InitWeightsUpdateGroupReq):
+ assert torch.distributed.is_initialized(), "Default torch process group must be initialized"
+
+ assert request.group_name != "", "Group name cannot be empty"
+ rank_offset = request.rank_offset
+ rank = rank_offset + self.rank_in_dp
+ world_size = request.world_size
+ group_name = request.group_name
+ self.logger.info(
+ f"init custom process group: master_address={request.master_address}, master_port={request.master_port}, "
+ f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, "
+ f" backend={request.backend}"
+ )
+
+ try:
+ if group_name in self._model_update_group:
+ raise ValueError(f"Process group with name {group_name} already exists.")
+
+ self._model_update_group[group_name] = init_custom_process_group(
+ backend=request.backend,
+ init_method=f"tcp://{request.master_address}:{request.master_port}",
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+ return True, "Succeeded to initialize custom process group."
+
+ except Exception as e:
+ message = f"Failed to initialize custom process group: {e}."
+ self.logger.error(message)
+ return False, message
+
+ def destroy_weights_update_group(self, request: DestroyWeightsUpdateGroupReq):
+ try:
+ if request.group_name in self._model_update_group:
+ pg = self._model_update_group.pop(request.group_name)
+ torch.distributed.destroy_process_group(pg)
+ return True, "Succeeded to destroy custom process group."
+ else:
+ return False, "The group to be destroyed does not exist."
+ except Exception as e:
+ message = f"Failed to destroy custom process group: {e}."
+ self.logger.error(message)
+ return False, message
+
+ def update_weights_from_distributed(self, request: UpdateWeightsFromDistributedReq):
+ """
+ Update specific parameter in the model weights online
+ through `_model_update_group` process group.
+
+ Args:
+ name: the name of the parameter to be updated.
+ dtype: the data type of the parameter to be updated.
+ shape: the shape of the parameter to be updated.
+ """
+
+ assert request.group_name in self._model_update_group, (
+ f"Group {request.group_name} not in {list(self._model_update_group.keys())}. "
+ "Please call `init_weights_update_group` first."
+ )
+
+ try:
+ weights = []
+ handles = []
+ for name, dtype, shape in zip(request.names, request.dtypes, request.shapes):
+ target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+ weight = torch.empty(shape, dtype=target_dtype, device="cuda")
+ handles.append(
+ torch.distributed.broadcast(
+ weight,
+ src=0,
+ group=self._model_update_group[request.group_name],
+ async_op=True,
+ )
+ )
+ weights.append((name, weight))
+ for handle in handles:
+ handle.wait()
+
+ self.model.load_weights(weights)
+ return True, "Succeeded to update parameter online from distributed."
+
+ except Exception as e:
+ error_msg = (
+ f"Failed to update parameter online: {e}. "
+ f"The full weights of the ModelRunner are partially updated. "
+ f"Please discard the whole weights."
+ )
+ self.logger.error(error_msg)
+ return False, error_msg
+
+ def _update_weights_from_flattened_bucket(
+ self,
+ flattened_tensor_bucket_dict,
+ ):
+ """Handle flattened bucket format for weight updates"""
+ flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+ metadata = flattened_tensor_bucket_dict["metadata"]
+
+ # Convert metadata dict to our format
+ converted_metadata = []
+ for meta in metadata:
+ converted_meta = FlattenedTensorMetadata(
+ name=meta.name,
+ shape=meta.shape,
+ dtype=meta.dtype,
+ start_idx=meta.start_idx,
+ end_idx=meta.end_idx,
+ numel=meta.numel,
+ )
+ converted_metadata.append(converted_meta)
+
+ # Create bucket and reconstruct tensors
+ bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=converted_metadata)
+ reconstructed_tensors = bucket.reconstruct_tensors()
+
+ named_tensors = {name: tensor for name, tensor in reconstructed_tensors}
+
+ # Load the reconstructed tensors using the standard method
+ self.model.load_weights(named_tensors)
+
+ return True, "Succeeded to update parameter online from flattened bucket tensor."
+
+ def update_weights_from_tensor(self, request: UpdateWeightsFromTensorReq):
+ try:
+ monkey_patch_torch_reductions()
+ if request.load_format == "flattened_bucket":
+ # Handle flattened bucket format
+ serialized_named_tensors = MultiprocessingSerializer.deserialize(
+ request.serialized_named_tensors[self.rank_in_dp]
+ )
+ return self._update_weights_from_flattened_bucket(flattened_tensor_bucket_dict=serialized_named_tensors)
+
+ # We need to get device after patch otherwise the device would be wrong
+ self.device_module = torch.get_device_module("cuda")
+ infered_device = self.device_module.current_device()
+
+ named_tensors = MultiprocessingSerializer.deserialize(request.serialized_named_tensors[self.rank_in_dp])
+
+ def _unwrap_tensor(tensor, tp_rank, device):
+ if isinstance(tensor, LocalSerializedTensor):
+ tensor = tensor.get(tp_rank)
+ clone = tensor.to(device).clone()
+ del tensor # free the ipc tensor
+ return clone
+
+ named_tensors = {
+ name: _unwrap_tensor(tensor, tp_rank=self.rank_in_dp, device=infered_device)
+ for name, tensor in named_tensors
+ }
+
+ self.model.load_weights(named_tensors)
+
+ return True, "Succeeded to update parameter online from tensor."
+
+ except Exception as e:
+ message = f"Failed to update parameter online from tensor. Reason: {e}."
+ self.logger.error(message)
+
+ return False, message
+
def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
"""
这个函数会把next token id和logprobs保存到pinned memory中
@@ -617,7 +818,7 @@ def _get_classed_reqs(
paused_reqs.append(req_obj)
continue
- if req_obj.infer_aborted or req_obj.finish_status.is_finished():
+ if req_obj.finish_status.is_finished():
if support_overlap:
# 延迟处理
req_obj.filter_mark = True
@@ -827,6 +1028,18 @@ def _sample_and_scatter_token(
)
return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu
+ def _flush_routing_to_kv_buffer(self, mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
+ """Scatter captured routing data from capture buffer to KV-indexed GPU buffer.
+
+ Must be called AFTER model.forward() completes. mem_indexes should be the
+ original (unpadded) tensor — either CPU or CUDA.
+ """
+ if _routing_mgr.g_routing_capture_manager is not None and mem_indexes is not None:
+ if not mem_indexes.is_cuda:
+ mem_indexes = mem_indexes.cuda(non_blocking=True)
+ num_tokens = mem_indexes.shape[0]
+ _routing_mgr.g_routing_capture_manager.flush_to_routing_buffer(mem_indexes, num_tokens, microbatch_index)
+
def _dp_all_gather_prefill_and_decode_req_num(
self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]
) -> Tuple[np.ndarray, np.ndarray]:
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index 60045fab6c..e068c00c76 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -109,6 +109,7 @@ def prefill_normal(
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits,
b_req_idx=model_input.b_req_idx,
@@ -140,6 +141,7 @@ def prefill_normal(
extra_post_req_handle_func=self.extra_post_req_handle_func,
nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func,
)
+
# 第四阶段
event_pack.notify_pre_post_handle()
return
@@ -152,6 +154,7 @@ def decode_normal(
model_input, run_reqs = prepare_decode_inputs(decode_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits,
b_req_idx=model_input.b_req_idx,
@@ -190,6 +193,7 @@ def prefill_mtp(
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits,
b_req_idx=model_input.b_req_idx,
@@ -244,6 +248,7 @@ def decode_mtp(
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
b_mtp_index_cpu = model_input.b_mtp_index
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id)
# verify the next_token_ids
b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0]
@@ -266,6 +271,7 @@ def decode_mtp(
key="mtp_accept_len",
gpu_tensor=mtp_accept_len,
)
+
verify_event = torch.cuda.Event()
verify_event.record()
diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
index 5a179cb620..ebc55b7ef4 100644
--- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py
@@ -40,8 +40,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq
)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
logits = model_output.logits
batch_idx, run_reqs = self._diverse_copy(
diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
index c83e8cd4a5..c9484dba6f 100644
--- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
@@ -151,6 +151,7 @@ def prefill_normal(
run_reqs_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
if run_reqs_num > 0:
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits[:run_reqs_num],
@@ -198,6 +199,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq
run_reqs_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
if run_reqs_num > 0:
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
logits=model_output.logits[:run_reqs_num],
@@ -246,6 +248,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
@@ -319,6 +323,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
@@ -373,6 +379,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
req_num = len(run_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output: ModelOutput = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num]
logits = model_output.logits[0:req_num, :]
b_req_idx = model_input.b_req_idx[0:req_num]
@@ -438,6 +445,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output = self.model.forward(model_input)
+ self._flush_routing_to_kv_buffer(model_input.mem_indexes)
mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None
if req_num > 0:
logits = model_output.logits[0:req_num, :]
@@ -646,6 +654,8 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
) = padded_overlap_prepare_prefill_inputs(prefill_reqs)
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
req_num0, req_num1 = len(run_reqs0), len(run_reqs1)
@@ -747,8 +757,9 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
b_mtp_index_cpu0 = model_input0.b_mtp_index
b_mtp_index_cpu1 = model_input1.b_mtp_index
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-
model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1)
+ self._flush_routing_to_kv_buffer(model_input0.mem_indexes, microbatch_index=0)
+ self._flush_routing_to_kv_buffer(model_input1.mem_indexes, microbatch_index=1)
logits0 = model_output0.logits
logits1 = model_output1.logits
run_reqs = run_reqs0 + run_reqs1
@@ -788,6 +799,20 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
)
all_next_token_ids.append(next_token_ids)
+ # Copy accepted buffer states back to buffer[0] for MTP
+ # Only copy when accept_len > 1
+ mask = mtp_accept_len > 1
+ if mask.sum() > 0:
+ actual_req_idxes = b_req_idx[b_req_mtp_start_loc[mask]]
+ src_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[
+ actual_req_idxes, mtp_accept_len[mask] - 1
+ ]
+ dst_buffer_indexes = g_infer_context.req_manager.req_to_buffer_index[actual_req_idxes, 0]
+ if hasattr(g_infer_context.req_manager.buffer_mem_manager, "copy_buffer_p2p"):
+ g_infer_context.req_manager.buffer_mem_manager.copy_buffer_p2p(
+ src_buffer_indexes, dst_buffer_indexes
+ )
+
verify_event = torch.cuda.Event()
verify_event.record()
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
index f3ad03662e..406c889b72 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
@@ -1,7 +1,8 @@
import torch
from typing import List, Tuple
-from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
-from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache
+from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty
+from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache
+from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids
from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
from lightllm.utils.envs_utils import get_env_start_args
@@ -15,7 +16,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
b_top_ks,
b_length_penalty_param,
b_mask_eos_reqs,
+ invalid_token_ids,
+ cu_invalid_token_num,
is_all_greedy,
+ has_invalid_token_ids,
skip_top_k,
skip_top_p,
exist_req_use_random_seed,
@@ -63,6 +67,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
eos_ids=eos_ids,
sampling_params_manager=sampling_params_manager,
)
+
+ if has_invalid_token_ids:
+ apply_invalid_token_ids(
+ Logits=logits,
+ invalid_token_ids=invalid_token_ids,
+ cu_invalid_token_num=cu_invalid_token_num,
+ )
+
logits.div_(b_temperatures.view((-1, 1)))
probs = torch.softmax(logits, dim=-1)
@@ -152,6 +164,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
skip_top_p = True
exist_req_use_random_seed = False
+ # invalid token ids
+ invalid_token_ids: List[int] = []
+ has_invalid_token_ids = False
+ cu_invalid_token_num = [0]
+ invalid_token_num_start = 0
+
for i, req_obj in enumerate(reqs):
sample_param = req_obj.sampling_param
shm_param = sample_param.shm_param
@@ -173,6 +191,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
if req_obj.generator is not None:
exist_req_use_random_seed = True
req_idxes.append(req_obj.req_idx)
+ invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids)
+ cu_invalid_token_num.append(invalid_token_num_start)
+ if len(req_obj.sampling_param.invalid_token_ids) > 0:
+ has_invalid_token_ids = True
+ invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids)
req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32)
temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32)
@@ -183,6 +206,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
)
mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool)
+ if has_invalid_token_ids:
+ invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True)
+ cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True)
+
return (
req_idxes_cpu.cuda(non_blocking=True),
temperatures_cpu.cuda(non_blocking=True),
@@ -190,7 +217,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
top_ks_cpu.cuda(non_blocking=True),
length_penalty_param_cpu.cuda(non_blocking=True),
mask_eos_reqs_cpu.cuda(non_blocking=True),
+ invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None,
+ cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None,
is_all_greedy,
+ has_invalid_token_ids,
skip_top_k,
skip_top_p,
exist_req_use_random_seed,
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 408b173371..2060753ae6 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -37,6 +37,8 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.torch_memory_saver_utils import MemoryTag
+from lightllm.server.io_struct import GeneralHttpToModelRpcReq, GeneralModelToHttpRpcRsp
logger = init_logger(__name__)
@@ -128,6 +130,35 @@ def exposed_init_model(self, kvargs):
def exposed_get_max_total_token_num(self):
return self.backend.get_max_total_token_num()
+ def release_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.backend.release_memory_occupation(tags)
+ return True
+ except BaseException as e:
+ logger.exception(f"release memory occupation failed: {str(e)}")
+ return False
+
+ def resume_memory_occupation(self, tags: List[MemoryTag]):
+ try:
+ self.backend.resume_memory_occupation(tags)
+ return True
+ except BaseException as e:
+ logger.exception(f"resume memory occupation failed: {str(e)}")
+ return False
+
+ def exposed_forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ try:
+ req = obtain(req)
+ if self.backend is None or not hasattr(self.backend, req.func_name):
+ raise ValueError(f"Backend does not support function {req.func_name}")
+ success, ret = getattr(self.backend, req.func_name)(req.func_args)
+ return GeneralModelToHttpRpcRsp(success=success, msg=str(ret), func_name=req.func_name, func_rsp=ret)
+ except BaseException as e:
+ logger.exception(f"forward to model backend failed: {str(e)}")
+ return GeneralModelToHttpRpcRsp(
+ success=False, msg=f"forward to model backend failed: {str(e)}", func_name=req.func_name
+ )
+
class ModelRpcClient:
def __init__(self, conn):
@@ -151,6 +182,7 @@ async def _func(*args, **kwargs):
self._init_model = async_wrap(self.conn.root.init_model)
self._get_max_total_token_num = async_wrap(self.conn.root.get_max_total_token_num)
+ self._forward_to_model = async_wrap(self.conn.root.forward_to_model)
return
async def init_model(self, kvargs):
@@ -162,6 +194,10 @@ async def get_max_total_token_num(self):
ans = self._get_max_total_token_num()
return obtain(await ans)
+ async def forward_to_model(self, req: GeneralHttpToModelRpcReq) -> GeneralModelToHttpRpcRsp:
+ ans = self._forward_to_model(req)
+ return obtain(await ans)
+
def _init_env(
args,
@@ -222,7 +258,11 @@ async def start_model_process(
success_event,
),
)
- proc.start()
+ from lightllm.utils.torch_memory_saver_utils import TorchMemorySaverWrapper
+
+ torch_memory_saver = TorchMemorySaverWrapper(args.enable_torch_memory_saver)
+ with torch_memory_saver.configure_subprocess():
+ proc.start()
# Use asyncio.to_thread to make the blocking wait non-blocking
await asyncio.to_thread(success_event.wait, timeout=40)
diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py
index 73113a59b8..e1e2479c86 100644
--- a/lightllm/server/router/req_queue/base_queue.py
+++ b/lightllm/server/router/req_queue/base_queue.py
@@ -33,6 +33,17 @@ def free_aborted_req_cpu_cache_pages(self, req: Req):
req.cpu_cache_match_page_indexes.clear()
self.router.cpu_cache_client.lock.release()
+ def free_aborted_req(self, req: Req):
+ # 为了让http server 能正常返回请求,还没有开始推理的请求,直接设置结束,返回空字符串
+ input_len = req.input_len
+ req.link_prompt_ids_shm_array()
+ req.link_logprobs_shm_array()
+ req.candetoken_out_len = 1
+ req.finish_token_index = input_len
+ req.shm_prompt_ids.arr[input_len] = self.args.eos_id[0]
+ req.shm_logprobs.arr[input_len] = 0
+ req.finish_status.set_status(FinishStatus.FINISHED_ABORTED)
+
def extend(self, req_group: List[Req]):
for req in req_group:
req.sample_params.suggested_dp_index = self.dp_index
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 884b5930b0..24f017d95c 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -84,8 +84,8 @@ def generate_new_batch(self, current_batch: Batch):
waiting_queue = self.waiting_req_list
for req in waiting_queue:
- if req.is_aborted:
- # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
+ if req.is_aborted and not self.router.is_multinode_tp:
+ # 由于管理的复杂性,只有没有被调度运行过的单节点请求可以因为abort直接在队列中忽略掉.
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏
aborted_count += 1
abort_req_list.append(req)
@@ -104,6 +104,7 @@ def generate_new_batch(self, current_batch: Batch):
req: Req = req
logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
self.free_aborted_req_cpu_cache_pages(req)
+ self.free_aborted_req(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
index 3b831c92a6..0b17bbd1c0 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
@@ -75,7 +75,7 @@ def generate_new_batch(self, current_batch: Batch):
waiting_queue = self.waiting_req_list
for req in waiting_queue:
- if req.is_aborted:
+ if req.is_aborted and not self.router.is_multinode_tp:
# 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token的泄漏
aborted_count += 1
@@ -95,6 +95,7 @@ def generate_new_batch(self, current_batch: Batch):
req: Req = req
logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
self.free_aborted_req_cpu_cache_pages(req)
+ self.free_aborted_req(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
index 4c2ebf7c00..1bfb8fc59e 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
@@ -41,7 +41,7 @@ def generate_new_batch(self, current_batch: Batch):
abort_req_list = []
aborted_count = 0
for req in self.waiting_req_list:
- if req.is_aborted:
+ if req.is_aborted and not self.router.is_multinode_tp:
# 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉.
# 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏
aborted_count += 1
@@ -58,6 +58,7 @@ def generate_new_batch(self, current_batch: Batch):
req: Req = req
logger.debug(f"router abort req id {req.request_id} shm_index: {req.index_in_shm_mem}")
self.free_aborted_req_cpu_cache_pages(req)
+ self.free_aborted_req(req)
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py
index a73823b8b7..e5f731df5f 100644
--- a/lightllm/server/router/req_queue/dp_base_queue.py
+++ b/lightllm/server/router/req_queue/dp_base_queue.py
@@ -27,6 +27,12 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
self.reqs_waiting_for_dp_index: List[List[Req]] = []
return
+ def free_aborted_req(self, req: Req):
+ dp_index = req.sample_params.suggested_dp_index
+ assert dp_index >= 0 and dp_index < self.dp_size_in_node
+ self.inner_queues[dp_index].free_aborted_req(req)
+ return
+
def get_dp_queue(self, dp_index: int):
assert dp_index < self.dp_size_in_node, "dp index out of range"
return self.inner_queues[dp_index]
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index 1dffdaf681..6b013c4c63 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -10,7 +10,6 @@
import threading
import collections
from typing import List
-from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
from lightllm.server.core.objs import ShmReqManager, StartArgs
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -21,6 +20,7 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.io_struct import BaseReq, GenerateReqIndex
from rpyc.utils.classic import obtain
@@ -91,7 +91,7 @@ async def wait_to_model_ready(self):
await asyncio.gather(*init_model_ret)
return
- def get_need_infer_images(self, group_req_indexes: GroupReqIndexes) -> List[ImageItem]:
+ def get_need_infer_images(self, group_req_indexes: GenerateReqIndex) -> List[ImageItem]:
shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
is_aborted = shm_req.is_aborted
disable_prompt_cache = shm_req.sample_params.disable_prompt_cache
@@ -121,7 +121,7 @@ def get_need_infer_images(self, group_req_indexes: GroupReqIndexes) -> List[Imag
return images_need_infer
- async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes):
+ async def handle_group_indexes(self, group_req_indexes: GenerateReqIndex):
images_need_infer = self.get_need_infer_images(group_req_indexes)
if len(images_need_infer) == 0:
@@ -174,13 +174,16 @@ async def infer_images(self, dp_index: int, images, events):
async def loop_for_netio_req(self):
try:
while True:
- recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
- if isinstance(recv_req, GroupReqIndexes):
+ recv_req: BaseReq = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj)
+ if isinstance(recv_req, GenerateReqIndex):
logger.info(
f"visual recv req id {recv_req.group_req_id} "
f"img count {len(recv_req.multimodal_params.images)}"
)
asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req))
+ elif isinstance(recv_req, BaseReq):
+ # RL 等控制类 BaseReq 透传给下一模块,最终由 router 处理
+ self.send_to_next_module.send_pyobj(recv_req, protocol=pickle.HIGHEST_PROTOCOL)
else:
assert False, f"Error Req Inf {recv_req}"
except Exception as e:
diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py
index 5b9705ed0e..87aae86e4f 100644
--- a/lightllm/utils/dist_utils.py
+++ b/lightllm/utils/dist_utils.py
@@ -80,12 +80,15 @@ def init_vision_distributed_env(kvargs):
device_id = kvargs["device_id"]
set_current_device_id(device_id)
torch.cuda.set_device(device_id)
+    # 不要在init_process_group时,显式地传入device_id
+ # 这会触发torch的device-bound split优化,会默认后面想加入新进程组的rank
+ # 都已经存在于默认组,这样RL更新weight的init_group时,外部想加入的组,在执行
+ # 通信原语时例如all_reduce,会永远等不到LightLLM默认组里的回复,从而导致错误结果。
dist.init_process_group(
"nccl",
init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
rank=kvargs["tp_rank_id"],
world_size=tp_world_size,
- device_id=torch.device(f"cuda:{device_id}"),
)
# warmup nccl communicator
_a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -150,7 +153,6 @@ def init_distributed_env(kvargs):
init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}',
rank=kvargs["rank_id"],
world_size=kvargs["world_size"],
- device_id=torch.device(f"cuda:{device_id}"),
)
# warmup nccl communicator
_a = torch.zeros([1]).to(f"cuda:{device_id}")
@@ -316,3 +318,71 @@ def _init_nccl_env():
assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}"
return
+
+
+# copy from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py#L1675
+def init_custom_process_group(
+ backend=None,
+ init_method=None,
+ timeout=None,
+ world_size=-1,
+ rank=-1,
+ store=None,
+ group_name=None,
+ pg_options=None,
+ device_id=None,
+):
+ from torch.distributed.distributed_c10d import (
+ Backend,
+ PrefixStore,
+ _new_process_group_helper,
+ _world,
+ default_pg_timeout,
+ rendezvous,
+ )
+
+ assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
+
+ if store is not None:
+ assert world_size > 0, "world_size must be positive if using store"
+ assert rank >= 0, "rank must be non-negative if using store"
+ elif init_method is None:
+ init_method = "env://"
+
+ if backend:
+ backend = Backend(backend)
+ else:
+ backend = Backend("undefined")
+
+ if timeout is None:
+ timeout = default_pg_timeout
+
+ # backward compatible API
+ if store is None:
+ rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+ store, rank, world_size = next(rendezvous_iterator)
+ store.set_timeout(timeout)
+
+ # Use a PrefixStore to avoid accidental overrides of keys used by
+ # different systems (e.g. RPC) in case the store is multi-tenant.
+ store = PrefixStore(group_name, store)
+
+ # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+ # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+ # We need to determine the appropriate parameter name based on PyTorch version
+ pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+ pg, _ = _new_process_group_helper(
+ world_size,
+ rank,
+ [],
+ backend,
+ store,
+ group_name=group_name,
+ **{pg_options_param_name: pg_options},
+ timeout=timeout,
+ device_id=device_id,
+ )
+
+ _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+ return pg
diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py
index b87096d945..c3a466191d 100644
--- a/lightllm/utils/net_utils.py
+++ b/lightllm/utils/net_utils.py
@@ -1,45 +1,92 @@
import socket
import subprocess
import ipaddress
-import random
+import os
from lightllm.utils.log_utils import init_logger
logger = init_logger(__name__)
+DEFAULT_BASE_PORT = 10000
+PORTS_PER_INSTANCE = 1000
+MAX_INSTANCE_ID = 7
-def alloc_can_use_network_port(num=3, used_ports=None, from_port_num=10000):
- port_list = []
- for port in range(from_port_num, 65536):
+
+def _is_port_available(port: int) -> bool:
+ """Check if a port is available by attempting to bind it."""
+ try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- result = s.connect_ex(("localhost", port))
- if result != 0 and port not in used_ports:
- port_list.append(port)
- if len(port_list) > num * 30:
- break
+ s.bind(("", port))
+ return True
+ except OSError:
+ return False
- if len(port_list) < num:
- return None
- random.shuffle(port_list)
- return port_list[0:num]
+def alloc_can_use_network_port(num=3, instance_id=0, used_ports=None):
+ """
+ Allocate available network ports within an instance-specific range.
+
+    Each instance gets a dedicated 1000-port range starting from the base port
+ (default 10000, override via LIGHTLLM_BASE_PORT env var).
+ Instance 0: 10000-10999, Instance 1: 11000-11999, etc.
+ """
+ if instance_id < 0 or instance_id > MAX_INSTANCE_ID:
+ raise ValueError(f"instance_id must be in range [0, {MAX_INSTANCE_ID}], got {instance_id}")
+
+ base_port = int(os.environ.get("LIGHTLLM_BASE_PORT", DEFAULT_BASE_PORT))
+ range_start = base_port + instance_id * PORTS_PER_INSTANCE
+ range_end = range_start + PORTS_PER_INSTANCE
+ used_set = set(used_ports) if used_ports else set()
+
+ port_list = []
+ for port in range(range_start, range_end):
+ if len(port_list) >= num:
+ break
+ if port in used_set:
+ continue
+ if _is_port_available(port):
+ port_list.append(port)
+ used_set.add(port)
+
+ if len(port_list) >= num:
+ logger.info(
+ f"Instance {instance_id}: allocated {len(port_list)} ports in [{range_start}, {range_end}): {port_list}"
+ )
+ return port_list
+
+ raise RuntimeError(
+ f"Failed to allocate {num} ports for instance {instance_id} in range [{range_start}, {range_end}). "
+ f"Only found {len(port_list)} available. Try a different instance_id or set LIGHTLLM_BASE_PORT."
+ )
def alloc_can_use_port(min_port, max_port):
port_list = []
for port in range(min_port, max_port):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- result = s.connect_ex(("localhost", port))
+ try:
+ test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ result = test_socket.connect_ex(("localhost", port))
+ test_socket.close()
+
if result != 0:
port_list.append(port)
+ except Exception:
+ continue
return port_list
def find_available_port(start_port, end_port):
for port in range(start_port, end_port + 1):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
- result = sock.connect_ex(("localhost", port))
+ try:
+ test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ result = test_socket.connect_ex(("localhost", port))
+ test_socket.close()
+
if result != 0:
return port
+ except Exception:
+ continue
return None
diff --git a/lightllm/utils/patch_torch.py b/lightllm/utils/patch_torch.py
new file mode 100644
index 0000000000..9f51edeb64
--- /dev/null
+++ b/lightllm/utils/patch_torch.py
@@ -0,0 +1,63 @@
+# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/patch_torch.py
+from typing import Callable, Union
+
+import torch
+from packaging import version
+from torch.multiprocessing import reductions
+
+
+def monkey_patch_torch_reductions():
+    """Monkey patch, needed until Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
+
+ # Currently, NPU does not support UUID. This has been temporarily commented out,
+ # with support expected in the fourth quarter.
+ # if _is_npu:
+ # return
+
+ if hasattr(reductions, "_reduce_tensor_original"):
+ return
+
+ reductions._reduce_tensor_original = reductions.reduce_tensor
+ reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor
+
+ reductions.reduce_tensor = _reduce_tensor_modified
+ reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified
+
+ reductions.init_reductions()
+
+
+# The signature has not been changed for years, and we will not need this when the next version is released,
+# so it looks safe to use a constant.
+_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6
+
+
+def _reduce_tensor_modified(*args, **kwargs):
+ output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs)
+ output_args = _modify_tuple(output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid)
+ return output_fn, output_args
+
+
+def _rebuild_cuda_tensor_modified(*args):
+ args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid)
+ return reductions._rebuild_cuda_tensor_original(*args)
+
+
+def _device_to_uuid(device: int) -> str:
+ return str(torch.cuda.get_device_properties(device).uuid)
+
+
+def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
+ if isinstance(device_maybe_uuid, int):
+ return device_maybe_uuid
+
+ if isinstance(device_maybe_uuid, str):
+ for device in range(torch.cuda.device_count()):
+ if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid:
+ return device
+ raise Exception("Invalid device_uuid=" + device_maybe_uuid)
+
+ raise Exception(f"Unknown type: {device_maybe_uuid=}")
+
+
+def _modify_tuple(t, index: int, modifier: Callable):
+ return *t[:index], modifier(t[index]), *t[index + 1 :]
diff --git a/lightllm/utils/serializer.py b/lightllm/utils/serializer.py
new file mode 100644
index 0000000000..d8180aeb0c
--- /dev/null
+++ b/lightllm/utils/serializer.py
@@ -0,0 +1,131 @@
+# copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils/common.py
+
+import base64
+import pickle
+import io
+from dataclasses import dataclass
+from multiprocessing.reduction import ForkingPickler
+from typing import List
+
+
+class MultiprocessingSerializer:
+ @staticmethod
+ def serialize(obj, output_str: bool = False):
+ """
+ Serialize a Python object using ForkingPickler.
+
+ Args:
+ obj: The object to serialize.
+ output_str (bool): If True, return a base64-encoded string instead of raw bytes.
+
+ Returns:
+ bytes or str: The serialized object.
+ """
+ buf = io.BytesIO()
+ ForkingPickler(buf).dump(obj)
+ buf.seek(0)
+ output = buf.read()
+
+ if output_str:
+ # Convert bytes to base64-encoded string
+ output = base64.b64encode(output).decode("utf-8")
+
+ return output
+
+ @staticmethod
+ def deserialize(data):
+ """
+ Deserialize a previously serialized object.
+
+ Args:
+ data (bytes or str): The serialized data, optionally base64-encoded.
+
+ Returns:
+ The deserialized Python object.
+ """
+ if isinstance(data, str):
+ # Decode base64 string to bytes
+ data = base64.b64decode(data, validate=True)
+
+ return SafeUnpickler(io.BytesIO(data)).load()
+
+
+class SafeUnpickler(pickle.Unpickler):
+ ALLOWED_MODULE_PREFIXES = {
+ # --- Python types ---
+ "builtins.",
+ "collections.",
+ "copyreg.",
+ "functools.",
+ "itertools.",
+ "operator.",
+ "types.",
+ "weakref.",
+ # --- PyTorch types ---
+ "torch.",
+ "torch._tensor.",
+ "torch.storage.",
+ "torch.nn.parameter.",
+ "torch.autograd.function.",
+ # --- torch distributed ---
+ "torch.distributed.",
+ "torch.distributed._shard.",
+ "torch.distributed._composable.",
+ "torch._C._distributed_c10d.",
+ "torch._C._distributed_fsdp.",
+ "torch.distributed.optim.",
+ # --- multiprocessing ---
+ "multiprocessing.resource_sharer.",
+ "multiprocessing.reduction.",
+ "pickletools.",
+ # --- PEFT / LoRA ---
+ "peft.",
+ "transformers.",
+ "huggingface_hub.",
+        # --- SGLang & unit tests ---
+ "sglang.srt.weight_sync.tensor_bucket.",
+ "sglang.srt.model_executor.model_runner.",
+ "sglang.srt.layers.",
+ "sglang.srt.utils.",
+ # --- LightLLM ---
+ "lightllm.utils.",
+ }
+
+ DENY_CLASSES = {
+ ("builtins", "eval"),
+ ("builtins", "exec"),
+ ("builtins", "compile"),
+ ("os", "system"),
+ ("subprocess", "Popen"),
+ ("subprocess", "run"),
+ ("codecs", "decode"),
+ ("types", "CodeType"),
+ ("types", "FunctionType"),
+ }
+
+ def find_class(self, module, name):
+ # Block deterministic attacks
+ if (module, name) in self.DENY_CLASSES:
+ raise RuntimeError(
+ f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164"
+ )
+ # Allowlist of safe-to-load modules.
+ if any((module + ".").startswith(prefix) for prefix in self.ALLOWED_MODULE_PREFIXES):
+ return super().find_class(module, name)
+
+ # Block everything else. (Potential attack surface)
+ raise RuntimeError(
+ f"Blocked unsafe class loading ({module}.{name}), " f"to prevent exploitation of CVE-2025-10164"
+ )
+
+
+@dataclass
+class LocalSerializedTensor:
+ """torch.Tensor that gets serialized by MultiprocessingSerializer
+ (which only serializes a pointer and not the data).
+ The i-th element in the list corresponds to i-th rank's GPU."""
+
+ values: List[bytes]
+
+ def get(self, rank: int):
+ return MultiprocessingSerializer.deserialize(self.values[rank])
diff --git a/lightllm/utils/tensor_bucket.py b/lightllm/utils/tensor_bucket.py
new file mode 100644
index 0000000000..a9d7a367dd
--- /dev/null
+++ b/lightllm/utils/tensor_bucket.py
@@ -0,0 +1,104 @@
+# copy from
+# https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/python/sglang/
+# srt/weight_sync/tensor_bucket.py
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+
+
+@dataclass
+class FlattenedTensorMetadata:
+ """Metadata for a tensor in a flattened bucket"""
+
+ name: str
+ shape: torch.Size
+ dtype: torch.dtype
+ start_idx: int
+ end_idx: int
+ numel: int
+
+
+class FlattenedTensorBucket:
+ """
+ A bucket that flattens multiple tensors into a single tensor for efficient processing
+ while preserving all metadata needed for reconstruction.
+ """
+
+    # This field is solely for users to check whether the class supports this feature
+ supports_multi_dtypes = True
+
+ def __init__(
+ self,
+ named_tensors: List[Tuple[str, torch.Tensor]] = None,
+ flattened_tensor: torch.Tensor = None,
+ metadata: List[FlattenedTensorMetadata] = None,
+ ):
+ """
+ Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+ Args:
+ named_tensors: List of (name, tensor) tuples (for creating new bucket)
+ flattened_tensor: Pre-flattened tensor (for reconstruction)
+ metadata: Pre-computed metadata (for reconstruction)
+ """
+ if named_tensors is not None:
+ # Create bucket from named tensors
+ self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+ self.flattened_tensor: torch.Tensor = None
+
+ if not named_tensors:
+ raise ValueError("Cannot create empty tensor bucket")
+
+ # Collect metadata and flatten tensors
+ current_idx = 0
+ flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+ for i, (name, tensor) in enumerate(named_tensors):
+ flattened = tensor.flatten().view(torch.uint8)
+ flattened_tensors[i] = flattened
+
+ # Store metadata
+
+ numel = flattened.numel()
+ metadata_obj = FlattenedTensorMetadata(
+ name=name,
+ shape=tensor.shape,
+ dtype=tensor.dtype,
+ start_idx=current_idx,
+ end_idx=current_idx + numel,
+ numel=numel,
+ )
+ self.metadata[i] = metadata_obj
+ current_idx += numel
+
+ # Concatenate all flattened tensors
+ self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+ else:
+ # Initialize from pre-flattened data
+ if flattened_tensor is None or metadata is None:
+ raise ValueError("Must provide either named_tensors or both flattened_tensor and metadata")
+ self.flattened_tensor = flattened_tensor
+ self.metadata = metadata
+
+ def get_flattened_tensor(self) -> torch.Tensor:
+ """Get the flattened tensor containing all bucket tensors"""
+ return self.flattened_tensor
+
+ def get_metadata(self) -> List[FlattenedTensorMetadata]:
+ """Get metadata for all tensors in the bucket"""
+ return self.metadata
+
+ def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
+ """
+ Reconstruct original tensors from flattened tensor with optimized performance.
+ Uses memory-efficient operations to minimize allocations and copies.
+ """
+ # preallocate the result list
+ reconstructed = [None] * len(self.metadata)
+
+ for i, meta in enumerate(self.metadata):
+ tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].view(meta.dtype).reshape(meta.shape)
+
+ reconstructed[i] = (meta.name, tensor)
+
+ return reconstructed
diff --git a/lightllm/utils/torch_memory_saver_utils.py b/lightllm/utils/torch_memory_saver_utils.py
new file mode 100644
index 0000000000..c1184ef30c
--- /dev/null
+++ b/lightllm/utils/torch_memory_saver_utils.py
@@ -0,0 +1,92 @@
+import torch
+from contextlib import contextmanager
+from enum import Enum
+from lightllm.utils.log_utils import init_logger
+
+try:
+ from torch_memory_saver import (
+ torch_memory_saver,
+ configure_subprocess,
+ )
+
+ HAS_TORCH_MEMORY_SAVER = True
+
+except ImportError:
+ HAS_TORCH_MEMORY_SAVER = False
+ pass
+
+logger = init_logger(__name__)
+
+
+class MemoryTag(Enum):
+    """Tags labelling the memory regions managed via torch_memory_saver."""
+
+    KV_CACHE = "kv_cache"
+    WEIGHT = "weights"
+    GRAPH = "graph"
+
+    def is_kv_cache(self):
+        return self == MemoryTag.KV_CACHE
+
+    def is_weight(self):
+        return self == MemoryTag.WEIGHT
+
+    def is_graph(self):
+        return self == MemoryTag.GRAPH
+
+    def __str__(self):
+        # Render as the underlying string value (the tag passed to the saver).
+        return self.value
+
+
+class TorchMemorySaverWrapper:
+    """Factory: returns a real saver when enabled, otherwise a no-op stand-in.
+
+    Implemented via ``__new__`` so callers can write
+    ``TorchMemorySaverWrapper(flag)`` and transparently receive either a
+    ``_TorchMemorySaver`` or a ``_TorchMemorySaverFake`` instance.
+    """
+
+    def __new__(cls, enable_torch_memory_saver: bool = False):
+        if enable_torch_memory_saver:
+            # Fail fast when the feature is requested but the package is missing.
+            assert (
+                HAS_TORCH_MEMORY_SAVER
+            ), "torch_memory_saver is not installed, please install it via `pip install torch_memory_saver`."
+            return _TorchMemorySaver()
+        else:
+            return _TorchMemorySaverFake()
+
+
+class _TorchMemorySaver:
+    """Thin adapter over the external ``torch_memory_saver`` package.
+
+    Every method forwards to the package API, converting :class:`MemoryTag`
+    members to their string values at the boundary.
+    """
+
+    def configure_subprocess(self):
+        # Semantics defined by torch_memory_saver; assumed to prepare
+        # child processes for saver use — TODO confirm against its docs.
+        return configure_subprocess()
+
+    def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
+        # Used as a context manager by callers (see basemodel __init__).
+        return torch_memory_saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup)
+
+    def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
+        # Always captured under the GRAPH tag; extra kwargs pass through.
+        return torch_memory_saver.cuda_graph(cuda_graph=graph_obj, **kwargs, tag=MemoryTag.GRAPH.value)
+
+    def disable(self):
+        return torch_memory_saver.disable()
+
+    def pause(self, tag: MemoryTag):
+        return torch_memory_saver.pause(tag=tag.value)
+
+    def resume(self, tag: MemoryTag):
+        return torch_memory_saver.resume(tag=tag.value)
+
+
+class _TorchMemorySaverFake:
+    """No-op stand-in used when the memory-saver feature is disabled.
+
+    Context-manager methods yield immediately; ``cuda_graph`` falls back to a
+    plain ``torch.cuda.graph`` capture; ``pause``/``resume`` only log a warning.
+    """
+
+    @contextmanager
+    def configure_subprocess(self):
+        yield
+
+    @contextmanager
+    def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
+        yield
+
+    def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, **kwargs):
+        # Degrade to a regular CUDA-graph capture context.
+        return torch.cuda.graph(graph_obj, **kwargs)
+
+    @contextmanager
+    def disable(self):
+        yield
+
+    def pause(self, tag: MemoryTag):
+        logger.warning("torch_memory_saver is not enabled, pause is not supported.")
+        return
+
+    def resume(self, tag: MemoryTag):
+        logger.warning("torch_memory_saver is not enabled, resume is not supported.")
+        return
diff --git a/requirements.txt b/requirements.txt
index d37ae05690..2ede1b24c1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -94,6 +94,7 @@ partial_json_parser==0.2.1.1.post6
websockets==15.0.1
cupy-cuda12x==13.6.0
nixl==0.8.0
+torch_memory_saver==0.0.9
xformers==0.0.33.post2
redis==7.3.0
litellm>=1.52.0,<1.85
diff --git a/test/test_api/test_r3.py b/test/test_api/test_r3.py
new file mode 100644
index 0000000000..85c4e44ef9
--- /dev/null
+++ b/test/test_api/test_r3.py
@@ -0,0 +1,92 @@
+import sys
+import argparse
+import requests
+import base64
+import numpy as np
+
+
+def test_routing_export(url: str = "http://localhost:8000"):
+ print(f"Testing routing export at {url}")
+ print("-" * 50)
+
+ try:
+ response = requests.post(
+ f"{url}/generate",
+ json={
+ "inputs": "What is the capital of France? What is the capital of France?",
+ "parameters": {
+ "max_new_tokens": 50,
+ # "return_routed_experts": True,
+ # "repetition_penalty": 1.0,
+ },
+ },
+ timeout=60,
+ )
+ except requests.exceptions.ConnectionError:
+ print(f"ERROR: Cannot connect to server at {url}")
+ print("Make sure the LightLLM server is running with --enable_return_routed_experts")
+ return False
+ except requests.exceptions.Timeout:
+ print("ERROR: Request timed out")
+ return False
+
+ print(f"Status: {response.status_code}")
+
+ if response.status_code != 200:
+ print(f"ERROR: Request failed with status {response.status_code}")
+ print(f"Response: {response.text}")
+ return False
+
+ res = response.json()
+ print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...")
+
+ if "routed_experts" not in res or not res["routed_experts"]:
+ print("\nWARNING: No routed_experts in response.")
+ print("This could mean:")
+ print(" - The model is not a MoE model")
+ print(" - The server was not started with --enable_return_routed_experts")
+ print(" - The routing capture manager was not initialized")
+ return False
+
+ routing_info = res["routed_experts"]
+ shape = routing_info["shape"]
+ dtype_str = routing_info["dtype"]
+ dtype = np.dtype(dtype_str)
+ data = base64.b64decode(routing_info["data"])
+ routing_array = np.frombuffer(data, dtype=dtype).reshape(shape)
+
+ print(f"\n{'=' * 50}")
+ print("ROUTING CAPTURE SUCCESS!")
+ print(f"{'=' * 50}")
+ print(f"Shape: {shape}")
+ print(f"Dtype: {dtype}")
+ print(f"Num tokens: {shape[0]}")
+ print(f"Num MoE layers: {shape[1]}")
+ print(f"Top-K: {shape[2]}")
+
+ # Compute payload size savings
+ int32_size = np.prod(shape) * 4
+ actual_size = len(data)
+ savings = (1 - actual_size / int32_size) * 100
+ print(f"Payload: {actual_size} bytes (vs {int32_size} bytes with int32, {savings:.0f}% smaller)")
+
+ print(f"\nSample routing (first layer, first 5 tokens):")
+ num_tokens_to_show = shape[0]
+ for i in range(num_tokens_to_show):
+ print(f" Token {i}: experts {routing_array[i, 0, :].tolist()}")
+
+ if np.all(routing_array == 0):
+ print("\nWARNING: All routing data is zeros. Capture may not be working correctly.")
+ return False
+
+ print("\nTest PASSED!")
+ return True
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Test R3 routing export feature")
+    parser.add_argument("--url", default="http://localhost:8000", help="Server URL")
+    args = parser.parse_args()
+
+    # Exit code mirrors the test outcome so CI can consume this script directly.
+    success = test_routing_export(args.url)
+    sys.exit(0 if success else 1)
diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py
new file mode 100644
index 0000000000..dcc010b372
--- /dev/null
+++ b/unit_tests/common/basemodel/test_routing_capture_manager.py
@@ -0,0 +1,219 @@
+import torch
+import numpy as np
+
+
+class TestRoutingCaptureManager:
+    """Unit tests for RoutingCaptureManager; every test allocates CUDA tensors.
+
+    NOTE(review): each test imports RoutingCaptureManager lazily inside its
+    body — presumably so collection succeeds when the package is unavailable;
+    confirm intent.
+    """
+
+    def test_capture_and_extract_basic(self):
+        """Test the core pipeline: capture → flush_to_kv_buffer → extract_from_gpu."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=4,
+            topk=8,
+            num_experts=64,
+            kv_cache_size=1024,
+            max_capture_tokens=64,
+        )
+
+        # Simulate a batch of 10 tokens at KV-cache positions [100..109]
+        mem_indexes = torch.arange(100, 110, device="cuda")
+
+        # Capture routing for each MoE layer (writes to capture buffer)
+        for layer_idx in range(4):
+            topk_ids = torch.randint(0, 64, (10, 8), device="cuda")
+            manager.capture(moe_layer_index=layer_idx, topk_ids=topk_ids, microbatch_index=0)
+
+        # Flush from capture buffer to KV-indexed gpu_kv_buffer
+        manager.flush_to_kv_buffer(mem_indexes, num_tokens=10, microbatch_index=0)
+
+        # Extract for those same KV-cache positions
+        result = manager.extract_from_gpu(mem_indexes)
+        assert result.shape == (4, 10, 8)
+        assert result.dtype == np.int8
+
+    def test_capture_writes_to_correct_kv_positions(self):
+        """Verify that captured data lands in the right KV-cache positions after flush."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=2,
+            topk=4,
+            num_experts=32,
+            kv_cache_size=256,
+            max_capture_tokens=16,
+        )
+
+        # Use non-contiguous mem_indexes to simulate real KV-cache
+        mem_indexes = torch.tensor([10, 50, 200], device="cuda")
+
+        topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], device="cuda")
+        manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0)
+
+        # Second layer gets distinct values so layer mix-ups are detectable.
+        topk_ids_layer1 = topk_ids + 20
+        manager.capture(moe_layer_index=1, topk_ids=topk_ids_layer1, microbatch_index=0)
+
+        # Flush to KV positions
+        manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0)
+
+        # Extract and verify
+        result = manager.extract_from_gpu(mem_indexes)
+        assert result.shape == (2, 3, 4)
+        np.testing.assert_array_equal(result[0], topk_ids.cpu().numpy())
+        np.testing.assert_array_equal(result[1], topk_ids_layer1.cpu().numpy())
+
+    def test_microbatch_isolation(self):
+        """Two microbatches writing to different KV positions don't interfere."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=1,
+            topk=4,
+            num_experts=32,
+            kv_cache_size=256,
+            max_capture_tokens=16,
+        )
+
+        # Microbatch 0: positions [10, 11]
+        mem0 = torch.tensor([10, 11], device="cuda")
+        ids_0 = torch.ones((2, 4), dtype=torch.int64, device="cuda")
+        manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0)
+
+        # Microbatch 1: positions [20, 21]
+        mem1 = torch.tensor([20, 21], device="cuda")
+        ids_1 = torch.ones((2, 4), dtype=torch.int64, device="cuda") * 2
+        manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1)
+
+        # Flush each microbatch to different KV positions
+        manager.flush_to_kv_buffer(mem0, num_tokens=2, microbatch_index=0)
+        manager.flush_to_kv_buffer(mem1, num_tokens=2, microbatch_index=1)
+
+        # Extract microbatch 0
+        result0 = manager.extract_from_gpu(mem0)
+        assert result0.shape == (1, 2, 4)
+        assert result0[0, 0, 0] == 1
+
+        # Extract microbatch 1
+        result1 = manager.extract_from_gpu(mem1)
+        assert result1[0, 0, 0] == 2
+
+    def test_dtype_selection_int8(self):
+        """Models with ≤127 experts use int8."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=1,
+            topk=2,
+            num_experts=64,
+            kv_cache_size=128,
+            max_capture_tokens=16,
+        )
+        assert manager.dtype == torch.int8
+        assert manager.np_dtype == np.int8
+        assert manager.dtype_id == 1
+
+    def test_dtype_selection_int16(self):
+        """Models with >127 experts use int16."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=1,
+            topk=2,
+            num_experts=256,
+            kv_cache_size=128,
+            max_capture_tokens=16,
+        )
+        assert manager.dtype == torch.int16
+        assert manager.np_dtype == np.int16
+        assert manager.dtype_id == 2
+
+    def test_extract_preserves_values(self):
+        """Extracted values exactly match what was captured, no dtype truncation."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=1,
+            topk=4,
+            num_experts=64,
+            kv_cache_size=64,
+            max_capture_tokens=16,
+        )
+
+        mem_indexes = torch.tensor([0, 1, 2], device="cuda")
+
+        # 127 is the int8 maximum — exercises the dtype boundary.
+        topk_ids = torch.tensor([[10, 20, 30, 40], [50, 60, 63, 1], [0, 5, 127, 3]], device="cuda")
+        manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0)
+
+        # Flush then extract
+        manager.flush_to_kv_buffer(mem_indexes, num_tokens=3, microbatch_index=0)
+        result = manager.extract_from_gpu(mem_indexes)
+        expected = topk_ids.cpu().numpy().astype(np.int8)
+        np.testing.assert_array_equal(result[0], expected)
+
+    def test_gpu_kv_buffer_shape(self):
+        """Buffer shape is (num_moe_layers, kv_cache_size, topk)."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        # 127 experts fits in int8 (max value 127)
+        manager = RoutingCaptureManager(
+            num_moe_layers=48,
+            topk=8,
+            num_experts=127,
+            kv_cache_size=2048,
+            max_capture_tokens=256,
+        )
+        assert manager.gpu_kv_buffer.shape == (48, 2048, 8)
+        assert manager.gpu_kv_buffer.dtype == torch.int8
+        assert manager.gpu_kv_buffer.device.type == "cuda"
+
+        # 128 experts requires int16
+        manager2 = RoutingCaptureManager(
+            num_moe_layers=48,
+            topk=8,
+            num_experts=128,
+            kv_cache_size=2048,
+            max_capture_tokens=256,
+        )
+        assert manager2.gpu_kv_buffer.dtype == torch.int16
+
+    def test_partial_token_capture(self):
+        """capture() only writes num_tokens rows to the buffer."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=1,
+            topk=2,
+            num_experts=32,
+            kv_cache_size=128,
+            max_capture_tokens=16,
+        )
+
+        # Capture 3 tokens; only KV positions 10-12 are flushed, 13-14 never are.
+        mem_indexes = torch.tensor([10, 11, 12, 13, 14], device="cuda")
+
+        topk_ids = torch.tensor([[1, 2], [3, 4], [5, 6]], device="cuda")  # only 3 tokens
+        manager.capture(moe_layer_index=0, topk_ids=topk_ids, microbatch_index=0)
+
+        # Flush only the 3 captured tokens
+        manager.flush_to_kv_buffer(mem_indexes[:3], num_tokens=3, microbatch_index=0)
+
+        # Positions 10-12 should have data, 13-14 should be zeros (from init)
+        result_written = manager.extract_from_gpu(mem_indexes[:3])
+        np.testing.assert_array_equal(result_written[0], topk_ids.cpu().numpy().astype(np.int8))
+
+        result_unwritten = manager.extract_from_gpu(mem_indexes[3:])
+        np.testing.assert_array_equal(result_unwritten[0], np.zeros((2, 2), dtype=np.int8))
+
+    def test_capture_buffer_shape(self):
+        """Capture buffer has correct shape (max_tokens, num_moe_layers, topk)."""
+        from lightllm.common.basemodel.routing_manager import RoutingCaptureManager
+
+        manager = RoutingCaptureManager(
+            num_moe_layers=4,
+            topk=8,
+            num_experts=64,
+            kv_cache_size=1024,
+            max_capture_tokens=256,
+        )
+        # One capture buffer per microbatch (indexes 0 and 1).
+        assert manager._capture_buffer[0].shape == (256, 4, 8)
+        assert manager._capture_buffer[1].shape == (256, 4, 8)
+        assert manager._capture_buffer[0].dtype == torch.int8
diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py
new file mode 100644
index 0000000000..3b2f159f62
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py
@@ -0,0 +1,50 @@
+import pytest
+import torch
+
+from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import (
+ apply_invalid_token_ids,
+)
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+def test_apply_invalid_token_ids(dtype):
+    """apply_invalid_token_ids masks each row's listed token ids with -inf in place."""
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for Triton kernels.")
+
+    batch_size = 4
+    vocab_size = 32
+    logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype)
+    expected = logits.clone()
+
+    # Per-row invalid ids; row 1 is deliberately empty (no masking for it).
+    invalid_token_ids_per_batch = [
+        [1, 3, 5],
+        [],
+        [0, 2, 31],
+        [7],
+    ]
+
+    # Flatten into a single id list plus exclusive prefix sums (CSR-like layout).
+    flat_ids = []
+    cu_invalid_token_num = [0]
+    invalid_token_num_start = 0
+    for ids in invalid_token_ids_per_batch:
+        flat_ids.extend(ids)
+        invalid_token_num_start += len(ids)
+        cu_invalid_token_num.append(invalid_token_num_start)
+
+    invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32)
+    # NOTE(review): the Python list above is shadowed here by its tensor form.
+    cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32)
+
+    # Reference result computed with plain fancy indexing.
+    for batch_idx, ids in enumerate(invalid_token_ids_per_batch):
+        if ids:
+            expected[batch_idx, ids] = float("-inf")
+
+    apply_invalid_token_ids(
+        Logits=logits,
+        invalid_token_ids=invalid_token_ids,
+        cu_invalid_token_num=cu_invalid_token_num,
+    )
+    assert torch.equal(logits, expected)
+
+
+# Allow running this test file directly, outside the pytest CLI.
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
index 605433e9d8..dfeda0b6f7 100644
--- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
+++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py
@@ -230,5 +230,32 @@ def test_case9():
assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64))
+def test_case10():
+    """
+    Test scenario: exercise the flush_cache function.
+    """
+    print("\nTest Case 10: Testing flush_cache function\n")
+    tree = RadixCache("unique_name", 100, 0)
+    tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
+    tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
+    # Take a live reference (update_refs=True) before flushing.
+    tree_node, size, values = tree.match_prefix(
+        torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+    )
+    assert tree_node is not None
+    assert size == 3
+    tree.flush_cache()
+    # After flushing, nothing matches and all token counters are reset...
+    tree_node, size, values = tree.match_prefix(
+        torch.tensor([1, 2, 3], dtype=torch.int64, device="cpu"), update_refs=True
+    )
+    assert tree_node is None
+    assert size == 0
+    assert tree.get_tree_total_tokens_num() == 0
+    assert tree.get_refed_tokens_num() == 0
+    assert len(tree.root_node.children) == 0
+    assert tree.root_node.token_id_key.numel() == 0
+    assert tree.root_node.token_mem_index_value.numel() == 0
+    # ...while the root node itself survives with ref_counter == 1.
+    assert tree.root_node.ref_counter == 1
+
+
if __name__ == "__main__":
pytest.main()