From 9c8064f0d7b180bd386179f8f74f16ae4e17b806 Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Thu, 7 May 2026 02:12:21 +0000
Subject: [PATCH 1/4] support prefill cudagraph

---
 .../layer_infer/transformer_layer_infer.py | 42 ++++++++++++++++++-
 1 file changed, 41 insertions(+), 1 deletion(-)

diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
index 3a8807ca5..66656345d 100644
--- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -1,3 +1,4 @@
+import gc
 import torch
 import torch.distributed as dist
 
@@ -7,6 +8,7 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
 from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager
 from typing import Tuple
 from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -221,11 +223,49 @@ def context_attention_forward(
         if not self.is_linear_attention_layer:
             return super().context_attention_forward(input_embdings, infer_state, layer_weight)
 
-        gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True)
+        gdn_out = self._gdn_wrapper_run(input_embdings, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return gdn_out
 
+    def _gdn_wrapper_run(
+        self,
+        input_embdings: torch.Tensor,
+        infer_state: Qwen3NextInferStateInfo,
+        layer_weight: Qwen3NextTransformerLayerWeight,
+    ) -> torch.Tensor:
+        if torch.cuda.is_current_stream_capturing():
+            x = input_embdings.contiguous()
+            _x = tensor_to_no_ref_tensor(x)
+            pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
+            pre_capture_graph.__exit__(None, None, None)
+
+            # Output shape mirrors gdn_forward's last step: linear_out_proj maps
+            # back to hidden_size, so o is (num_tokens, embed_dim_). We hardcode
+            # this rather than running a dry-run capture because FlashQLA's
+            # chunk_gated_delta_rule internally calls tilelang.cdiv(...).tolist()
+            # which requires a host-side sync — illegal during stream capture.
+            num_tokens = input_embdings.numel() // self.embed_dim_
+            o_shape = (num_tokens, self.embed_dim_)
+            o_dtype = input_embdings.dtype
+            o_device = input_embdings.device
+
+            infer_state.prefill_cuda_graph_create_graph_obj()
+            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
+            o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
+            _o = tensor_to_no_ref_tensor(o)
+
+            def gdn_func(new_infer_state: Qwen3NextInferStateInfo):
+                tmp_o = self.gdn_forward(_x, new_infer_state, layer_weight, is_prefill=True)
+                tmp_o = tmp_o.view(_o.shape)
+                _o.copy_(tmp_o)
+                return
+
+            infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_func, after_graph=pre_capture_graph)
+            return o
+
+        return self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True)
+
     def token_attention_forward(
         self,
         input_embdings,

From 3f9cd92c4b8d462efca63f09f947fe7808be6d7c Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Thu, 7 May 2026 03:24:22 +0000
Subject: [PATCH 2/4] narrow down the wrapper

---
 .../layer_infer/transformer_layer_infer.py | 98 ++++++++++---------
 1 file changed, 52 insertions(+), 46 deletions(-)

diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
index 66656345d..4ca71935d 100644
--- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -1,4 +1,3 @@
-import gc
 import torch
 import torch.distributed as dist
 
@@ -223,49 +222,11 @@ def context_attention_forward(
         if not self.is_linear_attention_layer:
             return super().context_attention_forward(input_embdings, infer_state, layer_weight)
 
-        gdn_out = self._gdn_wrapper_run(input_embdings, infer_state, layer_weight)
+        gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True)
         if self.tp_world_size_ > 1:
             all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         return gdn_out
 
-    def _gdn_wrapper_run(
-        self,
-        input_embdings: torch.Tensor,
-        infer_state: Qwen3NextInferStateInfo,
-        layer_weight: Qwen3NextTransformerLayerWeight,
-    ) -> torch.Tensor:
-        if torch.cuda.is_current_stream_capturing():
-            x = input_embdings.contiguous()
-            _x = tensor_to_no_ref_tensor(x)
-            pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
-            pre_capture_graph.__exit__(None, None, None)
-
-            # Output shape mirrors gdn_forward's last step: linear_out_proj maps
-            # back to hidden_size, so o is (num_tokens, embed_dim_). We hardcode
-            # this rather than running a dry-run capture because FlashQLA's
-            # chunk_gated_delta_rule internally calls tilelang.cdiv(...).tolist()
-            # which requires a host-side sync — illegal during stream capture.
-            num_tokens = input_embdings.numel() // self.embed_dim_
-            o_shape = (num_tokens, self.embed_dim_)
-            o_dtype = input_embdings.dtype
-            o_device = input_embdings.device
-
-            infer_state.prefill_cuda_graph_create_graph_obj()
-            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
-            o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
-            _o = tensor_to_no_ref_tensor(o)
-
-            def gdn_func(new_infer_state: Qwen3NextInferStateInfo):
-                tmp_o = self.gdn_forward(_x, new_infer_state, layer_weight, is_prefill=True)
-                tmp_o = tmp_o.view(_o.shape)
-                _o.copy_(tmp_o)
-                return
-
-            infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_func, after_graph=pre_capture_graph)
-            return o
-
-        return self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True)
-
     def token_attention_forward(
         self,
         input_embdings,
@@ -289,16 +250,13 @@ def gdn_forward(
         assert isinstance(infer_state.mem_manager, Qwen3NextMemManager)
 
         input = input.view(-1, self.embed_dim_)
-        conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
-
         mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
-        mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill)
 
         if is_prefill:
-            core_attn_out = self._gdn_prefill_kernel(
-                mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight
-            )
+            core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight)
         else:
+            mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=True)
+            conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
             core_attn_out = self._gdn_decode_kernel(
                 mixed_qkv,
                 conv_states,
@@ -317,6 +275,54 @@ def gdn_forward(
         output = layer_weight.linear_out_proj.mm(core_attn_out)
         return output
+    def _gdn_prefill_wrapper_run(
+        self,
+        mixed_qkvzba: torch.Tensor,
+        infer_state: Qwen3NextInferStateInfo,
+        layer_weight: Qwen3NextTransformerLayerWeight,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if torch.cuda.is_current_stream_capturing():
+            mixed_qkvzba = mixed_qkvzba.contiguous()
+            _mixed_qkvzba = tensor_to_no_ref_tensor(mixed_qkvzba)
+            pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
+            pre_capture_graph.__exit__(None, None, None)
+
+            # _gdn_prefill_kernel returns the pre-projection value stream. Its
+            # logical size is num_tokens * local value heads * value head dim.
+            # We avoid a dry-run because FlashQLA may do host-side syncs while
+            # preparing varlen chunk metadata, which is illegal during capture.
+            num_tokens = mixed_qkvzba.shape[0]
+            o_shape = (num_tokens, self.tp_num_v_heads, self.head_v_dim)
+            o_dtype = mixed_qkvzba.dtype
+            o_device = mixed_qkvzba.device
+            z_shape = o_shape
+
+            infer_state.prefill_cuda_graph_create_graph_obj()
+            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
+            o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
+            _o = tensor_to_no_ref_tensor(o)
+            z = torch.empty(z_shape, dtype=o_dtype, device=o_device)
+            _z = tensor_to_no_ref_tensor(z)
+
+            def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
+                conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
+                mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
+                _z.copy_(tmp_z)
+                tmp_o = self._gdn_prefill_kernel(
+                    mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
+                )
+                tmp_o = tmp_o.view(_o.shape)
+                _o.copy_(tmp_o)
+                return
+
+            infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_prefill_func, after_graph=pre_capture_graph)
+            return o, z
+
+        conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
+        mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=False)
+        core_attn_out = self._gdn_prefill_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight)
+        return core_attn_out, z
+
     def _split_qkvzba(self, mixed_qkvzba, is_decode=False):
         qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim
         z_end = qkv_dim + self.tp_value_dim
         b_end = z_end + self.tp_num_v_heads

From e8790b1dbc631b343fd0a47e42fb10055440903f Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 8 May 2026 08:06:52 +0000
Subject: [PATCH 3/4] fix

---
 test/acc/test_qwen3.5.sh | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/test/acc/test_qwen3.5.sh b/test/acc/test_qwen3.5.sh
index 69169ec23..a590aa083 100644
--- a/test/acc/test_qwen3.5.sh
+++ b/test/acc/test_qwen3.5.sh
@@ -9,6 +9,15 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \
 # second
 export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
 
+# prefill cuda graph functional test
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \
+--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \
+--tp 2 \
+--port 8089 \
+--enable_prefill_cudagraph
+
+export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
+
 # test
 
 LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \

From cbb8bd7f1fc92cf5f83705f9de3e247d20044f8b Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 8 May 2026 09:08:16 +0000
Subject: [PATCH 4/4] fix

---
 .../qwen3next/layer_infer/transformer_layer_infer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
index 4ca71935d..bb48bfe49 100644
--- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -255,7 +255,7 @@ def gdn_forward(
         if is_prefill:
             core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight)
         else:
-            mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=True)
+            mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba)
             conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
             core_attn_out = self._gdn_decode_kernel(
                 mixed_qkv,
@@ -306,7 +306,7 @@ def _gdn_prefill_wrapper_run(
 
             def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
                 conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
-                mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba, is_decode=False)
+                mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba)
                 _z.copy_(tmp_z)
                 tmp_o = self._gdn_prefill_kernel(
                     mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
@@ -319,11 +319,11 @@ def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
             return o, z
 
         conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
-        mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=False)
+        mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba)
         core_attn_out = self._gdn_prefill_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight)
         return core_attn_out, z
 
-    def _split_qkvzba(self, mixed_qkvzba, is_decode=False):
+    def _split_qkvzba(self, mixed_qkvzba):
         qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim
         z_end = qkv_dim + self.tp_value_dim
         b_end = z_end + self.tp_num_v_heads
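
Note on the capture-splitting idea these patches rely on: while the prefill pass is being captured into a CUDA graph, the GDN prefill kernel cannot be captured because it synchronizes with the host, so the wrapper closes the graph currently being captured, registers the kernel as a CPU-side function to run between graph replays, and opens a fresh graph for the rest of the layer. The sketch below is a minimal illustration of that pattern using plain torch.cuda.CUDAGraph objects only; it does not use lightllm's prefill_cuda_graph_* helpers, and the buffer names, host_side_op, and run_prefill are illustrative assumptions rather than anything defined in these patches.

import torch

# Static buffers: CUDA graphs replay fixed addresses, so inputs, the handoff
# tensor, and outputs all live in pre-allocated storage (requires a CUDA device).
x = torch.zeros(16, 64, device="cuda")
h = torch.empty_like(x)                      # intermediate handed to the eager op
y = torch.empty(16, 1, device="cuda")

# Warm up the ops on a side stream before capturing (standard CUDA-graph practice).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    h.copy_(torch.relu(x) * 2.0)
    y.copy_(h.sum(dim=1, keepdim=True))
torch.cuda.current_stream().wait_stream(s)

graph_before = torch.cuda.CUDAGraph()
graph_after = torch.cuda.CUDAGraph()

# First graph: everything up to the op that needs a host-side sync.
with torch.cuda.graph(graph_before):
    h.copy_(torch.relu(x) * 2.0)

# Second graph: everything after it.
with torch.cuda.graph(graph_after):
    y.copy_(h.sum(dim=1, keepdim=True))


def host_side_op(buf: torch.Tensor) -> None:
    # Stand-in for the GDN prefill kernel: .item() forces a device sync,
    # exactly the kind of call that is illegal inside a capture.
    buf.mul_(float(buf.abs().max().item()) + 1.0)


def run_prefill(new_x: torch.Tensor) -> torch.Tensor:
    x.copy_(new_x)           # refresh the static input buffer
    graph_before.replay()    # captured prefix
    host_side_op(h)          # eager region executed between the two graphs
    graph_after.replay()     # captured suffix reads the updated buffer
    return y.clone()


print(run_prefill(torch.randn(16, 64, device="cuda")).shape)

The eager call between the two replays plays the same role as the gdn_func / gdn_prefill_func callbacks that the patches register via prefill_cuda_graph_add_cpu_runnning_func: the host-syncing work runs outside every captured segment, while the static output buffers let the following captured segment consume its result.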