diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
index 3a8807ca5..bb48bfe49 100644
--- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -7,6 +7,7 @@
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
 from lightllm.common.kv_cache_mem_manager import Qwen3NextMemManager
 from typing import Tuple
 from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -249,16 +250,13 @@ def gdn_forward(
         assert isinstance(infer_state.mem_manager, Qwen3NextMemManager)
         input = input.view(-1, self.embed_dim_)
 
-        conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
-
         mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
-        mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill)
 
         if is_prefill:
-            core_attn_out = self._gdn_prefill_kernel(
-                mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight
-            )
+            core_attn_out, z = self._gdn_prefill_wrapper_run(mixed_qkvzba, infer_state, layer_weight)
         else:
+            mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba)
+            conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
             core_attn_out = self._gdn_decode_kernel(
                 mixed_qkv,
                 conv_states,
@@ -277,7 +275,55 @@ def gdn_forward(
         output = layer_weight.linear_out_proj.mm(core_attn_out)
         return output
 
-    def _split_qkvzba(self, mixed_qkvzba, is_decode=False):
+    def _gdn_prefill_wrapper_run(
+        self,
+        mixed_qkvzba: torch.Tensor,
+        infer_state: Qwen3NextInferStateInfo,
+        layer_weight: Qwen3NextTransformerLayerWeight,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if torch.cuda.is_current_stream_capturing():
+            mixed_qkvzba = mixed_qkvzba.contiguous()
+            _mixed_qkvzba = tensor_to_no_ref_tensor(mixed_qkvzba)
+            pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
+            pre_capture_graph.__exit__(None, None, None)
+
+            # _gdn_prefill_kernel returns the attention output before the out
+            # projection; its logical shape is (num_tokens, tp_num_v_heads, head_v_dim).
+            # We avoid dry-running the kernel because FlashQLA may sync with the
+            # host while preparing varlen chunk metadata, which is illegal during capture.
+            num_tokens = mixed_qkvzba.shape[0]
+            o_shape = (num_tokens, self.tp_num_v_heads, self.head_v_dim)
+            o_dtype = mixed_qkvzba.dtype
+            o_device = mixed_qkvzba.device
+            z_shape = o_shape
+
+            infer_state.prefill_cuda_graph_create_graph_obj()
+            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
+            o = torch.empty(o_shape, dtype=o_dtype, device=o_device)
+            _o = tensor_to_no_ref_tensor(o)
+            z = torch.empty(z_shape, dtype=o_dtype, device=o_device)
+            _z = tensor_to_no_ref_tensor(z)
+
+            def gdn_prefill_func(new_infer_state: Qwen3NextInferStateInfo):
+                conv_states, ssm_states = new_infer_state.req_manager.get_mamba_cache(self.layer_num_)
+                mixed_qkv, tmp_z, b, a = self._split_qkvzba(_mixed_qkvzba)
+                _z.copy_(tmp_z)
+                tmp_o = self._gdn_prefill_kernel(
+                    mixed_qkv, conv_states, ssm_states, a, b, new_infer_state, layer_weight
+                )
+                tmp_o = tmp_o.view(_o.shape)
+                _o.copy_(tmp_o)
+                return
+
+            infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=gdn_prefill_func, after_graph=pre_capture_graph)
+            return o, z
+
+        conv_states, ssm_states = infer_state.req_manager.get_mamba_cache(self.layer_num_)
+        mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba)
+        core_attn_out = self._gdn_prefill_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight)
+        return core_attn_out, z
+
+    def _split_qkvzba(self, mixed_qkvzba):
         qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim
         z_end = qkv_dim + self.tp_value_dim
         b_end = z_end + self.tp_num_v_heads
diff --git a/test/acc/test_qwen3.5.sh b/test/acc/test_qwen3.5.sh
index 69169ec23..a590aa083 100644
--- a/test/acc/test_qwen3.5.sh
+++ b/test/acc/test_qwen3.5.sh
@@ -9,6 +9,15 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \
 
 # second
 export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
 
+# prefill cuda graph feature test
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \
+--model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17 \
+--tp 2 \
+--port 8089 \
+--enable_prefill_cudagraph
+
+export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/Qwen3.5-0.8B", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
+
 # test
 LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server \
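
For review context, here is a minimal, self-contained sketch of the capture-splitting pattern `_gdn_prefill_wrapper_run` relies on. This is not LightLLM's implementation; apart from the torch CUDA graph APIs, every name below is hypothetical. The idea: when an op must sync with the host (illegal inside CUDA graph capture), end the current capture, hand the op stable input/output buffers, resume capture into a fresh graph, and at replay time run the op eagerly between the two graphs, which is the role `prefill_cuda_graph_add_cpu_runnning_func` registers `gdn_prefill_func` for in the patch. Requires a CUDA device.

```python
import torch

def build(x):
    # Warm up outside capture so lazy CUDA init does not happen during it.
    (x * 2 + 1).sum().item()
    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

    with torch.cuda.graph(g1):
        y = x * 2                   # graph-capturable work

    # Capture is paused here; allocate a stable buffer for the syncing op.
    buf = torch.empty_like(y)

    def host_syncing_op():          # stands in for the FlashQLA prefill call
        scale = y.max().item()      # device->host sync: illegal while capturing
        buf.copy_(y * scale)

    with torch.cuda.graph(g2):
        out = buf + 1               # capture resumes, reading the stable buffer

    return g1, host_syncing_op, g2, out

x = torch.randn(8, device="cuda")
g1, cpu_func, g2, out = build(x)
for _ in range(3):                  # replay: graph, eager CPU-driven op, graph
    g1.replay()
    cpu_func()
    g2.replay()
torch.cuda.synchronize()
print(out)
```

Keeping `y`, `buf`, and `out` alive outside the graphs is what makes the replays valid: CUDA graphs replay into fixed addresses, so every tensor that crosses a graph boundary must be a stable, pre-allocated buffer. This mirrors how the patch copies the kernel results into the pre-allocated `o`/`z` tensors.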
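
Similarly, a small sketch of the fused-projection split `_split_qkvzba` performs, assuming only the dimension relations visible in the hunk (qkv spans 2 * tp_key_dim + tp_value_dim, z is as wide as v, b takes one scalar per local value head); the width of a and all concrete sizes below are assumptions for illustration.

```python
import torch

# Assumed per-GPU (tensor-parallel) sizes; made up for illustration.
tp_key_dim, tp_value_dim, tp_num_v_heads = 512, 1024, 8
num_tokens = 16

# One fused GEMM (linear_in_proj in the diff) emits q, k, v, the gate z,
# and the per-head scalars b and a concatenated along the last dim.
fused = torch.randn(num_tokens, 2 * tp_key_dim + 2 * tp_value_dim + 2 * tp_num_v_heads)

mixed_qkv, z, b, a = torch.split(
    fused,
    [
        2 * tp_key_dim + tp_value_dim,  # q and k (tp_key_dim each) plus v
        tp_value_dim,                   # z: gate, same width as v
        tp_num_v_heads,                 # b: one scalar per local value head
        tp_num_v_heads,                 # a: one scalar per local value head (assumed)
    ],
    dim=-1,
)
assert z.shape == (num_tokens, tp_value_dim)
assert b.shape == a.shape == (num_tokens, tp_num_v_heads)
```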