diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
index f00d572d9..bdd86eb51 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py
@@ -46,7 +46,8 @@ def _select_experts(
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
         )
-        topk_weights.mul_(self.routed_scaling_factor)
+        if self.routed_scaling_factor != 1.0:
+            topk_weights.mul_(self.routed_scaling_factor)
         if self.redundancy_expert_num > 0:
             redundancy_topk_ids_repair(
                 topk_ids=topk_ids,
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
index 8bcdb4bf9..d6e923a11 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py
@@ -57,7 +57,8 @@ def _select_experts(
             num_expert_group=num_expert_group,
             scoring_func=scoring_func,
         )
-        topk_weights.mul_(self.routed_scaling_factor)
+        if self.routed_scaling_factor != 1.0:
+            topk_weights.mul_(self.routed_scaling_factor)
         if self.num_fused_shared_experts > 0:
             pad_topk_ids = (
                 torch.arange(
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
index 72c3a381e..59d1f825a 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
@@ -46,7 +46,7 @@ def fused_topk(
     sgl_ops.topk_softmax(
         topk_weights,
         topk_ids,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
+        gating_output,
         renormalize=renormalize,
     )
     return topk_weights, topk_ids
diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
index 40322e509..d3fc1b7db 100644
--- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
+++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
@@ -64,3 +64,139 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
         num_warps=4,
     )
     return x
+
+
+@triton.jit
+def _qk_rms_norm_fused_kernel(
+    # Q Pointers & Strides
+    Q_ptr,
+    WQ_ptr,
+    stride_q_row,
+    stride_q_col,
+    # K Pointers & Strides
+    K_ptr,
+    WK_ptr,
+    stride_k_row,
+    stride_k_col,
+    # Dimensions
+    num_heads_q: tl.constexpr,  # number of Q heads (used as the Q/K boundary)
+    head_dim: tl.constexpr,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Program id 0: which token (row) is processed.
+    row_idx = tl.program_id(0)
+    # Program id 1: which head (combined index) is processed.
+    # It ranges over [0, num_heads_q + num_heads_k).
+    combo_head_idx = tl.program_id(1)
+
+    # Shared column offsets (0 ~ head_dim).
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    # === Branch: decide whether this program handles Q or K ===
+    if combo_head_idx < num_heads_q:
+        # ------------------ Q path ------------------
+        # Pointer arithmetic.
+        # For Q, the actual head index is combo_head_idx itself.
+        Q_ptr += row_idx * stride_q_row
+
+        # Locate the Q data: base + row offset + head offset + column offset.
+        q_ptr_offset = (combo_head_idx * head_dim + offs) * stride_q_col
+
+        # Load the Q data.
+        x = tl.load(Q_ptr + q_ptr_offset).to(tl.float32)
+        # RMSNorm computation.
+        var = tl.sum(x * x, axis=0) / head_dim
+        rstd = 1 / tl.sqrt(var + eps)
+
+        # Load the Q weight (all heads share one weight vector of size head_dim).
+        w = tl.load(WQ_ptr + offs).to(tl.float32)
+
+        y = x * rstd * w
+
+        # Write back to Q.
+        tl.store(Q_ptr + q_ptr_offset, y.to(Q_ptr.dtype.element_ty))
+
+    else:
+        # ------------------ K path ------------------
+        # Remap to the K head index (starting from 0).
+        k_head_idx = combo_head_idx - num_heads_q
+
+        # Pointer arithmetic.
+        K_ptr += row_idx * stride_k_row
+        k_ptr_offset = (k_head_idx * head_dim + offs) * stride_k_col
+
+        # Load the K data.
+        x = tl.load(K_ptr + k_ptr_offset).to(tl.float32)
+        # RMSNorm computation.
+        var = tl.sum(x * x, axis=0) / head_dim
+        rstd = 1 / tl.sqrt(var + eps)
+
+        # Load the K weight.
+        w = tl.load(WK_ptr + offs).to(tl.float32)
+
+        y = x * rstd * w
+
+        # Write back to K.
+        tl.store(K_ptr + k_ptr_offset, y.to(K_ptr.dtype.element_ty))
+
+
+def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6):
+    """
+    In-place RMSNorm for both Q and K in a single kernel launch.
+    Supports GQA (different number of heads for Q and K).
+
+    Args:
+        q: (Total_Tokens, Hidden_Q) or (B, S, H_q, D) -> flattened to 2D inside
+        k: (Total_Tokens, Hidden_K)
+        w_q: (head_dim,) Scale parameter for Q
+        w_k: (head_dim,) Scale parameter for K
+    """
+    # 1. Shape and contiguity checks.
+    # View both inputs as 2D (Total_Tokens, Hidden_Size) tensors.
+    q_view = q.view(-1, q.shape[-1])
+    k_view = k.view(-1, k.shape[-1])
+
+    assert w_q.is_contiguous() and w_k.is_contiguous()
+
+    M = q_view.shape[0]  # Total tokens
+    assert k_view.shape[0] == M, "Q and K must have the same number of tokens"
+
+    head_dim = w_q.shape[0]
+    assert w_k.shape[0] == head_dim, "Head dim of Q and K must match"
+
+    # Compute the number of heads.
+    N_q = q_view.shape[1]
+    N_k = k_view.shape[1]
+
+    assert N_q % head_dim == 0
+    assert N_k % head_dim == 0
+
+    num_heads_q = N_q // head_dim
+    num_heads_k = N_k // head_dim
+
+    # 2. Block size setup.
+    BLOCK_SIZE = triton.next_power_of_2(head_dim)
+    assert BLOCK_SIZE == head_dim, "Currently only supports head_dim that is a power of 2 (e.g., 64, 128)"
+
+    # 3. Launch the kernel.
+    # Grid: (num_tokens, num_heads_q + num_heads_k)
+    grid = (M, num_heads_q + num_heads_k)
+
+    _qk_rms_norm_fused_kernel[grid](
+        q_view,
+        w_q,
+        q_view.stride(0),
+        q_view.stride(1),
+        k_view,
+        w_k,
+        k_view.stride(0),
+        k_view.stride(1),
+        num_heads_q=num_heads_q,
+        head_dim=head_dim,
+        eps=eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=4,
+    )
+
+    return q, k
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
index c8100c676..519fd497f 100644
--- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -17,6 +17,15 @@
         "num_stages": 3,
         "num_warps": 4
     },
+    "192": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
     "2048": {
         "BLOCK_SIZE_K": 32,
         "BLOCK_SIZE_M": 32,
@@ -35,6 +44,15 @@
         "num_stages": 2,
         "num_warps": 4
     },
+    "384": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 16,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
     "512": {
         "BLOCK_SIZE_K": 64,
         "BLOCK_SIZE_M": 16,
@@ -53,6 +71,24 @@
         "num_stages": 2,
         "num_warps": 4
     },
+    "640": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "768": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
     "8": {
         "BLOCK_SIZE_K": 32,
         "BLOCK_SIZE_M": 16,
@@ -79,5 +115,23 @@
         "NEED_TRANS": false,
         "num_stages": 2,
         "num_warps": 4
+    },
+    "896": {
+        "BLOCK_SIZE_K": 32,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "96": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
+        "num_stages": 4,
+        "num_warps": 4
     }
 }
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
index 4142ee983..26a6d63c4 100644
--- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
+++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -26,6 +26,24 @@ "num_stages": 3, "num_warps": 4 }, + "112": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "12": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, "128": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, @@ -44,6 +62,15 @@ "num_stages": 3, "num_warps": 4 }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "256": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, @@ -62,6 +89,15 @@ "num_stages": 3, "num_warps": 4 }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, "64": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, @@ -79,5 +115,23 @@ "NEED_TRANS": false, "num_stages": 3, "num_warps": 8 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json index 002b842cb..ea17f7f5a 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json @@ -11,6 +11,14 @@ "BLOCK_SIZE": 256, "num_warps": 4 }, + "112": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "12": { + "BLOCK_SIZE": 512, + "num_warps": 8 + }, "128": { "BLOCK_SIZE": 256, "num_warps": 8 @@ -19,6 +27,14 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "2": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "256": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -27,6 +43,10 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "48": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "64": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -34,5 +54,13 @@ "8": { "BLOCK_SIZE": 128, "num_warps": 8 + }, + "80": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "96": { + "BLOCK_SIZE": 128, + "num_warps": 8 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json index bc904bb7f..6f5752b8d 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -17,6 +17,18 @@ "NUM_STAGE": 4, "num_warps": 1 }, + "112": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "12": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, "128": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -29,6 +41,12 @@ "NUM_STAGE": 1, "num_warps": 2 }, + "24": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "256": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -41,6 +59,12 @@ "NUM_STAGE": 4, "num_warps": 4 }, + "48": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "64": { "BLOCK_DIM": 128, "BLOCK_M": 1, @@ -52,5 +76,17 @@ "BLOCK_M": 1, "NUM_STAGE": 1, "num_warps": 16 + }, + "80": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "96": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 50499a3e7..37f18f454 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -11,6 +11,12 @@ "NUM_STAGES": 1, "num_warps": 8 }, + "192": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, "2048": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -23,6 +29,12 @@ "NUM_STAGES": 1, "num_warps": 8 }, + "384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, "512": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -35,6 +47,18 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, "8": { "BLOCK_M": 1, "BLOCK_N": 64, @@ -52,5 +76,17 @@ "BLOCK_N": 256, "NUM_STAGES": 4, "num_warps": 1 + }, + "896": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "96": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 9eccddffc..391581627 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -5,6 +5,7 @@ import numpy as np import triton from typing import Tuple +from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -64,9 +65,11 @@ def _get_qkv( q, cache_kv = qkv.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - 
layer_weight.q_norm_weight_(q, eps=self.eps_)
-        layer_weight.k_norm_weight_(
+        qk_rmsnorm_fused_forward(
+            q,
             cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+            layer_weight.q_norm_weight_.weight,
+            layer_weight.k_norm_weight_.weight,
             eps=self.eps_,
         )
         cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
index ca3901ebd..2afa06d38 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
@@ -16,6 +16,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
         b_length_penalty_param,
         b_mask_eos_reqs,
         is_all_greedy,
+        is_all_random,
     ) = _get_post_sample_tensors(reqs)
 
     eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True)
@@ -68,6 +69,11 @@
         batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1))
         return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1)
+    elif is_all_random:
+        batch_next_token_ids = _random_sample(probs)
+        batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1))
+        return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1)
+
     elif get_env_start_args().sampling_backend == "triton":
         probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks)
         sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
@@ -104,6 +110,12 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     return probs_sort, probs_idx
 
 
+def _random_sample(probs: torch.Tensor):
+    q = torch.empty_like(probs)
+    q.exponential_()
+    return probs.div(q).argmax(dim=-1).view(-1)
+
+
 def _get_post_sample_tensors(reqs: List[InferReq]):
     req_idxes: List[int] = []
     temperatures: List[float] = []
@@ -112,6 +124,7 @@
     length_penalty_param: List[int] = []
     mask_eos_reqs: List[bool] = []
     is_all_greedy = True
+    is_all_random = True
 
     for i, req_obj in enumerate(reqs):
         sample_param = req_obj.sampling_param
@@ -127,6 +140,8 @@
         top_ks.append(top_k_val)
         if top_k_val > 1:
             is_all_greedy = False
+        if top_k_val != -1 or shm_param.top_p != 1.0:
+            is_all_random = False
         req_idxes.append(req_obj.req_idx)
 
     req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32)
@@ -146,4 +161,5 @@
         length_penalty_param_cpu.cuda(non_blocking=True),
         mask_eos_reqs_cpu.cuda(non_blocking=True),
         is_all_greedy,
+        is_all_random,
     )
diff --git a/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py b/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py
new file mode 100644
index 000000000..2210972ca
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py
@@ -0,0 +1,52 @@
+import pytest
+import torch
+
+from lightllm.common.basemodel.triton_kernel.norm.qk_norm import (
+    qk_rmsnorm_fused_forward,
+    qk_rmsnorm_forward,
+)
+
+
+def test_qk_rmsnorm_fused_matches_reference():
+    """Compare fused QK RMSNorm with separate reference RMSNorm kernels."""
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for qk_rmsnorm_fused test")
+
+    torch.manual_seed(0)
+
+    # Mock configuration: Batch=2, Seq=128, Head_Dim=128
+    # Q: 16 heads, K: 4 heads (GQA setting)
+    B, S, D = 2, 128, 128
+    H_Q = 16
+    H_K = 4
+
+    q = torch.randn((B * S, H_Q * D), device="cuda", dtype=torch.bfloat16)
+    k = torch.randn((B * S, H_K * D), device="cuda", dtype=torch.bfloat16)
+
+    w_q = torch.ones((D,), device="cuda", dtype=torch.bfloat16)
+    w_k = torch.ones((D,), device="cuda", dtype=torch.bfloat16)
+
+    # Keep copies for comparison (the reference kernels operate on these clones).
+    q_ref = q.clone()
+    k_ref = k.clone()
+
+    # The fused kernel computes in place.
+    q_out, k_out = qk_rmsnorm_fused_forward(q, k, w_q, w_k, eps=1e-6)
+
+    # Reference: apply RMSNorm to Q and K separately.
+    q_ref_out = qk_rmsnorm_forward(q_ref, w_q, eps=1e-6)
+    k_ref_out = qk_rmsnorm_forward(k_ref, w_k, eps=1e-6)
+
+    # The fused kernel is in-place, so the returned q_out/k_out must alias q/k.
+    assert q_out.data_ptr() == q.data_ptr()
+    assert k_out.data_ptr() == k.data_ptr()
+
+    # Error tolerance: computation is in bfloat16, so use a reasonable atol.
+    q_max_diff = (q_out - q_ref_out).abs().max().item()
+    k_max_diff = (k_out - k_ref_out).abs().max().item()
+
+    print(f"Q max diff: {q_max_diff}")
+    print(f"K max diff: {k_max_diff}")
+
+    assert q_max_diff < 1e-5, f"Q max diff too large: {q_max_diff}"
+    assert k_max_diff < 1e-5, f"K max diff too large: {k_max_diff}"
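
Note (not part of the patch): the new _random_sample path in generic_post_process.py handles requests with pure sampling (top_k == -1, top_p == 1.0). Dividing the probabilities by i.i.d. Exponential(1) noise and taking the row-wise argmax selects index i with probability probs[i] (the "exponential race", equivalent to the Gumbel-max trick), so no torch.multinomial call is needed. Below is a minimal self-contained sketch that checks this empirically; the names exponential_race_sample and check_exponential_race_sampling are illustrative and not part of the repository.

import torch


def exponential_race_sample(probs: torch.Tensor) -> torch.Tensor:
    # Same trick as _random_sample in the patch: argmax of probs / Exp(1) noise
    # draws index i with probability probs[i] for each row.
    noise = torch.empty_like(probs).exponential_()
    return probs.div(noise).argmax(dim=-1).view(-1)


def check_exponential_race_sampling(num_draws: int = 200_000) -> None:
    torch.manual_seed(0)
    target = torch.tensor([0.1, 0.2, 0.3, 0.4])
    probs = target.repeat(num_draws, 1)
    samples = exponential_race_sample(probs)
    freq = torch.bincount(samples, minlength=4).float() / num_draws
    # Empirical frequencies should be close to the target distribution.
    assert torch.allclose(freq, target, atol=1e-2), freq


if __name__ == "__main__":
    check_exponential_race_sampling()
    print("exponential-race sampling matches the target distribution")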
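Note (not part of the patch): per head of size head_dim, both the fused kernel and the existing qk_rmsnorm_forward compute y = x / sqrt(mean(x^2) + eps) * w in float32 and cast back, with one shared weight vector per tensor (w_q for all Q heads, w_k for all K heads). The following is a hedged pure-PyTorch sketch of that per-head math, useful for debugging; the function name qk_rmsnorm_torch_ref is illustrative only.

import torch


def qk_rmsnorm_torch_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (tokens, num_heads * head_dim), w: (head_dim,)
    tokens, hidden = x.shape
    head_dim = w.shape[0]
    assert hidden % head_dim == 0
    xf = x.view(tokens, hidden // head_dim, head_dim).float()
    # Per-head RMS normalization followed by the shared scale.
    rstd = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = xf * rstd * w.float()
    return y.view(tokens, hidden).to(x.dtype)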