From 669965bb56af1c784246c1befcfe7d92a35b1334 Mon Sep 17 00:00:00 2001 From: tangzhiyi Date: Sat, 18 Apr 2026 11:36:48 +0000 Subject: [PATCH 1/7] [Ascend] MTP speculative decoding support (qwen3_5_mtp_final_2) - op_backend.py: MTP detection (is_multi_token_decoding), effective_is_decoding, actual_seq_lengths_q, vendor_device_init trigger - attention.py: add is_multi_token_decoding and actual_seq_lengths_q fields - pagedattention.py: MTP verify reuses paged_prefill_attention - config.py: SpecDecodeConfig.from_config add device_type param - config_builder.py: pass device_type to SpecDecodeConfig Co-Authored-By: Claude Opus 4.6 --- .../backends/dlinfer/ascend/op_backend.py | 40 ++++++++++++++----- .../pytorch/backends/dlinfer/attention.py | 4 ++ lmdeploy/pytorch/config.py | 2 + lmdeploy/pytorch/engine/config_builder.py | 1 + .../pytorch/kernels/dlinfer/pagedattention.py | 33 ++++++++++++++- 5 files changed, 68 insertions(+), 12 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 8cc41f5df8..40439165c0 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -156,8 +156,16 @@ def update_step_context(cls, step_context): block_num, block_size, *_ = step_context.kv_caches[0][0].shape is_prefill_no_cache = False + is_multi_token_decoding = False + actual_seq_lengths_q = None if not step_context.is_decoding: is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) + else: + # Speculative decoding: main model decodes multiple tokens per sequence + is_multi_token_decoding = torch.max(step_context.q_seqlens).item() > 1 + effective_is_decoding = step_context.is_decoding and not is_multi_token_decoding + if is_multi_token_decoding: + actual_seq_lengths_q = step_context.q_seqlens.cpu().cumsum(0).to(torch.int32) if step_context.block_offsets.dtype != torch.int32: step_context.block_offsets = step_context.block_offsets.to(torch.int32) if step_context.kv_seqlens.dtype != torch.int32: @@ -173,19 +181,20 @@ def get_total_slots(): cls.total_slots = cls.total_slots.view(block_num, block_size) return cls.total_slots - def get_cpu_seqlens(is_decoding, is_prefill_no_cache): + def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding): """Get sequence lengths on CPU. Returns: q_seqlens_cpu: query sequence lengths (per sequence). kv_seqlens_cpu: kv sequence lengths (per sequence), used for list/max seqlens calculation. - kv_seqlens_expanded: kv sequence lengths expanded per token via - repeat_interleave, used for attention metadata. """ - if is_decoding: + if is_decoding and not is_multi_token_decoding: q_seqlens_cpu = None kv_seqlens_cpu = step_context.kv_seqlens.cpu() + elif is_multi_token_decoding: + q_seqlens_cpu = step_context.q_seqlens.cpu() + kv_seqlens_cpu = step_context.kv_seqlens.cpu() elif is_prefill_no_cache: q_seqlens_cpu = step_context.q_seqlens.cpu() kv_seqlens_cpu = q_seqlens_cpu @@ -353,16 +362,17 @@ def get_moe_group_name(group): group_name = backend.get_hccl_comm_name(local_rank) return group_name - q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache) - q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu, + q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(effective_is_decoding, is_prefill_no_cache, + is_multi_token_decoding) + q_seqlens_list, kv_seqlens_list = get_list_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_cpu, kv_seqlens_cpu) - max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list, + max_q_seq_len, max_kv_seq_len = get_max_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list) - kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding, + kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(effective_is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len) - q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu) + q_seqlens_cpu = update_q_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_cpu) if not cls.enable_graph and step_context.kv_quant_policy == 8: record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') @@ -385,12 +395,12 @@ def get_moe_group_name(group): q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype, device=step_context.q_seqlens.device) cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() - if not step_context.is_decoding: + if not effective_is_decoding: has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens) attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( - step_context.is_decoding, + effective_is_decoding, step_context.block_offsets, # cu_seqlens is only used in GDN and is passed down via q_start_loc. # Otherwise, q_start_loc is None. @@ -406,6 +416,8 @@ def get_moe_group_name(group): quant_policy=step_context.kv_quant_policy, quant_meta=AscendKVQuantMeta.quant_meta, has_initial_state=has_initial_state, + is_multi_token_decoding=is_multi_token_decoding, + actual_seq_lengths_q=actual_seq_lengths_q, ) step_context.attn_metadata = attn_metadata @@ -462,6 +474,12 @@ def init(): logger.warning(f'Error during Ascend initialization: {str(e)}. ' 'Please check your Ascend environment configuration.') + try: + import dlinfer.framework.lmdeploy_ext.device # noqa: F401 — triggers vendor_device_init() + except ImportError: + logger.warning('dlinfer framework extensions not found. ' + 'Ascend-specific model patches will not be applied.') + try: from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton init_device_properties_triton() diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index a8eea27545..232ced9e59 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -19,6 +19,8 @@ class DlinferAttentionMetadata(AttentionMetadata): quant_meta: dict = None cu_seq_lens_kv: Tensor | None = None has_initial_state: Tensor | None = None + is_multi_token_decoding: bool = False + actual_seq_lengths_q: Tensor | None = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): @@ -143,6 +145,8 @@ def forward( kv_scales=kv_scales, kv_zeros=kv_zeros, quant_bits=quant_bits, + is_multi_token_decoding=attn_metadata.is_multi_token_decoding, + actual_seq_lengths_q=attn_metadata.actual_seq_lengths_q, ) return attn_output diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 9b26222d31..6c5c286a84 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -566,6 +566,7 @@ def from_config( target_cache_cfg: CacheConfig, target_model: str = None, dtype: str = 'auto', + device_type: str = 'auto', ): model = model or target_model model_config = ModelConfig.from_pretrained(model, @@ -574,6 +575,7 @@ def from_config( is_draft_model=True, spec_method=method, block_size=target_cache_cfg.block_size, + device_type=device_type, ) cache_config = None # include medusa diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index 3c5a005daf..96e886eadf 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -114,5 +114,6 @@ def build_specdecode_config(target_model, speculative_config: SpeculativeConfig, target_model=target_model, target_cache_cfg=cache_config, dtype=engine_config.dtype, + device_type=engine_config.device_type, ) return specdecode_config diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 5e75de0a5b..a913dc024b 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -137,8 +137,39 @@ def paged_attention_fwd( kv_scales: Tensor | None = None, kv_zeros: Tensor | None = None, quant_bits: int | None = 0, + is_multi_token_decoding: bool = False, + actual_seq_lengths_q: Tensor | None = None, ): - if not is_decoding: + if is_multi_token_decoding: + # MTP verify is semantically a "mini-prefill": multiple tokens per + # sequence with TND layout, sparse_mode=3 and causal mask. + # Reuse paged_prefill_attention, passing cumulative q lengths. + return prefill_attention( + query_states, + key_states, + value_states, + attn_output, + key_cache, + value_cache, + block_offsets, + q_start_loc, + actual_seq_lengths_q, + kv_seqlens, + cu_seq_lens_kv, + max_q_seq_len, + max_kv_seq_len, + block_size, + num_heads, + num_kv_heads, + v_head_size, + attn_mask, + softmax_scale, + is_prefill_no_cache=False, + kv_scales=kv_scales, + kv_zeros=kv_zeros, + quant_bits=quant_bits, + ) + elif not is_decoding: return prefill_attention( query_states, key_states, From 85f8bc4d62d51b8dae5080a6cdbc8daf0040fe4f Mon Sep 17 00:00:00 2001 From: tangzhiyi Date: Sun, 19 Apr 2026 09:03:02 +0000 Subject: [PATCH 2/7] [Ascend] Minimize lmdeploy MTP hooks for dlinfer Keep only the generic draft-step and accepted-token metadata plumbing in lmdeploy so the dlinfer backend can drive Ascend multi-token state updates without broad runtime hooks in the core runtime. Made-with: Cursor --- .../backends/dlinfer/ascend/op_backend.py | 43 ++++++++--- .../pytorch/backends/dlinfer/attention.py | 2 + .../pytorch/kernels/dlinfer/pagedattention.py | 8 ++ lmdeploy/pytorch/models/deepseek_mtp.py | 6 ++ lmdeploy/pytorch/models/qwen3_5_mtp.py | 9 ++- .../spec_decode/proposers/deepseek_mtp.py | 76 ++++++++++++++++++- lmdeploy/pytorch/spec_decode/spec_agent.py | 10 +-- .../pytorch/strategies/ar_spec/step_inputs.py | 26 ++++++- 8 files changed, 160 insertions(+), 20 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 40439165c0..0dd3634e81 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -158,14 +158,29 @@ def update_step_context(cls, step_context): is_prefill_no_cache = False is_multi_token_decoding = False actual_seq_lengths_q = None + num_accepted_tokens = None + # Pre-compute CPU q_seqlens for decode path (reused in get_cpu_seqlens) + q_seqlens_cpu_early = None if not step_context.is_decoding: is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) else: - # Speculative decoding: main model decodes multiple tokens per sequence - is_multi_token_decoding = torch.max(step_context.q_seqlens).item() > 1 + q_seqlens_cpu_early = step_context.q_seqlens.cpu() + is_multi_token_decoding = q_seqlens_cpu_early.max().item() > 1 effective_is_decoding = step_context.is_decoding and not is_multi_token_decoding if is_multi_token_decoding: - actual_seq_lengths_q = step_context.q_seqlens.cpu().cumsum(0).to(torch.int32) + actual_seq_lengths_q = q_seqlens_cpu_early.cumsum(0).to(torch.int32) + if step_context.is_decoding and step_context.model_metas is not None: + accepted = [] + for model_meta in step_context.model_metas: + if isinstance(model_meta, dict): + accepted.append(int(model_meta.get('num_accepted_tokens', 1))) + else: + accepted.append(1) + num_accepted_tokens = torch.tensor( + accepted, + dtype=torch.int32, + device=step_context.block_offsets.device, + ) if step_context.block_offsets.dtype != torch.int32: step_context.block_offsets = step_context.block_offsets.to(torch.int32) if step_context.kv_seqlens.dtype != torch.int32: @@ -181,7 +196,8 @@ def get_total_slots(): cls.total_slots = cls.total_slots.view(block_num, block_size) return cls.total_slots - def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding): + def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding, + _q_seqlens_cpu_early=None): """Get sequence lengths on CPU. Returns: @@ -193,7 +209,7 @@ def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding): q_seqlens_cpu = None kv_seqlens_cpu = step_context.kv_seqlens.cpu() elif is_multi_token_decoding: - q_seqlens_cpu = step_context.q_seqlens.cpu() + q_seqlens_cpu = _q_seqlens_cpu_early kv_seqlens_cpu = step_context.kv_seqlens.cpu() elif is_prefill_no_cache: q_seqlens_cpu = step_context.q_seqlens.cpu() @@ -257,9 +273,14 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_ device=step_context.block_offsets.device), diagonal=max_kv_seq_len - max_q_seq_len + 1)) else: + mask_width = 2048 + causal_width = min(max_kv_seq_len, mask_width) attention_mask.append( - torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=step_context.block_offsets.device), - diagonal=1)) + torch.triu(torch.ones(mask_width, + mask_width, + dtype=torch.bool, + device=step_context.block_offsets.device), + diagonal=causal_width - max_q_seq_len + 1)) kv_start_indices = torch.cat(kv_start_indices) @@ -286,7 +307,7 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group): if ep_size <= 1: return 0, 0, 0 # get padded_tokens_current_rank - is_graph = cls.enable_graph and step_context.is_decoding + is_graph = cls.enable_graph and effective_is_decoding if is_graph: from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size actual_tokens_current_rank = step_context.q_seqlens.shape[0] @@ -320,7 +341,7 @@ def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size): if ep_size <= 1: return DlinferMoECommType.ALLGATHER mc2_token_capacity = init_mc2_token_capacity(tp_size) - is_graph = cls.enable_graph and step_context.is_decoding + is_graph = cls.enable_graph and effective_is_decoding if is_graph: max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size if SocVersion.is_A2(): @@ -363,7 +384,7 @@ def get_moe_group_name(group): return group_name q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(effective_is_decoding, is_prefill_no_cache, - is_multi_token_decoding) + is_multi_token_decoding, q_seqlens_cpu_early) q_seqlens_list, kv_seqlens_list = get_list_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_cpu, kv_seqlens_cpu) max_q_seq_len, max_kv_seq_len = get_max_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_list, @@ -418,6 +439,8 @@ def get_moe_group_name(group): has_initial_state=has_initial_state, is_multi_token_decoding=is_multi_token_decoding, actual_seq_lengths_q=actual_seq_lengths_q, + num_accepted_tokens=num_accepted_tokens, + kv_seqlens_device=step_context.kv_seqlens, ) step_context.attn_metadata = attn_metadata diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 232ced9e59..702f1457fb 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -21,6 +21,8 @@ class DlinferAttentionMetadata(AttentionMetadata): has_initial_state: Tensor | None = None is_multi_token_decoding: bool = False actual_seq_lengths_q: Tensor | None = None + num_accepted_tokens: Tensor | None = None + kv_seqlens_device: Tensor | None = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index a913dc024b..5814e91970 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -141,6 +141,14 @@ def paged_attention_fwd( actual_seq_lengths_q: Tensor | None = None, ): if is_multi_token_decoding: + if actual_seq_lengths_q is None: + raise ValueError('MTP multi-token decode requires actual_seq_lengths_q for TND attention.') + if actual_seq_lengths_q.dim() != 1 or kv_seqlens.dim() != 1: + raise ValueError('TND attention expects 1D q/kv length tensors.') + if block_offsets.size(0) != actual_seq_lengths_q.numel(): + raise ValueError('TND attention expects per-sequence block tables.') + if kv_seqlens.numel() != actual_seq_lengths_q.numel(): + raise ValueError('TND attention expects kv lengths per sequence.') # MTP verify is semantically a "mini-prefill": multiple tokens per # sequence with TND layout, sparse_mode=3 and causal mask. # Reuse paged_prefill_attention, passing cumulative q lengths. diff --git a/lmdeploy/pytorch/models/deepseek_mtp.py b/lmdeploy/pytorch/models/deepseek_mtp.py index a36a14cd34..073efce304 100644 --- a/lmdeploy/pytorch/models/deepseek_mtp.py +++ b/lmdeploy/pytorch/models/deepseek_mtp.py @@ -601,6 +601,11 @@ def prepare_inputs_for_generation( position_ids = context.position_ids attn_metadata = context.attn_metadata target_hidden_states = context.target_hidden_states + spec_step_idx = 0 + if context.model_metas: + model_meta = context.model_metas[0] + if isinstance(model_meta, dict): + spec_step_idx = int(model_meta.get('spec_step_idx', 0)) return dict( input_ids=input_ids, position_ids=position_ids, @@ -608,6 +613,7 @@ def prepare_inputs_for_generation( attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, target_hidden_states=target_hidden_states, + spec_step_idx=spec_step_idx, ) def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter], diff --git a/lmdeploy/pytorch/models/qwen3_5_mtp.py b/lmdeploy/pytorch/models/qwen3_5_mtp.py index 78cdd7172e..312e168425 100644 --- a/lmdeploy/pytorch/models/qwen3_5_mtp.py +++ b/lmdeploy/pytorch/models/qwen3_5_mtp.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. - from collections.abc import Iterable from typing import Any @@ -227,6 +226,7 @@ def forward( attn_metadata: Any, inputs_embeds: torch.Tensor | None = None, mrope_position_ids: torch.Tensor | None = None, + spec_step_idx: int = 0, **kwargs, ): """Model forward, return logits.""" @@ -244,6 +244,7 @@ def forward( attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, mrope_position_ids=mrope_position_ids, + spec_step_idx=spec_step_idx, previous_hidden_states=target_hidden_states, all_routed_experts=all_routed_experts, ) @@ -294,6 +295,11 @@ def prepare_inputs_for_generation( attn_metadata = context.attn_metadata target_hidden_states = context.target_hidden_states mrope_position_ids = getattr(context, 'mrope_position_ids', None) + spec_step_idx = 0 + if context.model_metas: + model_meta = context.model_metas[0] + if isinstance(model_meta, dict): + spec_step_idx = int(model_meta.get('spec_step_idx', 0)) if context.target_inputs_embeds is not None: inputs_embeds = context.target_inputs_embeds @@ -305,6 +311,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, target_hidden_states=target_hidden_states, mrope_position_ids=mrope_position_ids, + spec_step_idx=spec_step_idx, ) def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter], diff --git a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py index c0d65ca33b..92b29e0477 100644 --- a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py +++ b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py @@ -14,6 +14,77 @@ @SPEC_PROPOSERS.register_module(name='deepseek_mtp') class DeepseekMTP(BaseSpecProposer): + @staticmethod + def get_spec_step_idx(model_inputs: ModelInputs) -> int: + """Read the current draft step from model_metas.""" + model_metas = model_inputs.model_metas + if not model_metas: + return 0 + model_meta = model_metas[0] + if not isinstance(model_meta, dict): + return 0 + return int(model_meta.get('spec_step_idx', 0)) + + @staticmethod + def with_spec_step_idx( + model_metas: list[dict] | None, + batch_size: int, + spec_step_idx: int, + ): + """Attach spec_step_idx to every batch meta entry.""" + if model_metas is None: + model_metas = [None] * batch_size + + updated = [] + for batch_idx in range(batch_size): + model_meta = model_metas[batch_idx] if batch_idx < len(model_metas) else None + if model_meta is None: + model_meta = {} + else: + model_meta = dict(model_meta) + model_meta['spec_step_idx'] = spec_step_idx + updated.append(model_meta) + return updated + + def update_inputs_decoding( + self, + model_inputs: ModelInputs, + extra_inputs: ARSpecExtraInputs, + next_input_ids: torch.Tensor, + target_hidden_states: torch.Tensor, + model_metas: list[dict], + ): + """Update decoding inputs with deepseek-style spec step metadata.""" + new_inputs = super().update_inputs_decoding( + model_inputs, + extra_inputs, + next_input_ids, + target_hidden_states, + model_metas, + ) + return new_inputs.clone( + model_metas=self.with_spec_step_idx( + model_metas, + new_inputs.seq_length.size(0), + 0, + ) + ) + + def get_logits(self, hidden_states: torch.Tensor, spec_step_idx: int = 0): + """Get logits of deepseek/qwen-style MTP draft models.""" + draft_model = self.model + if not isinstance(draft_model, torch.nn.Module): + draft_model = draft_model.model + + if hasattr(draft_model, 'get_logits'): + try: + logits = draft_model.get_logits(hidden_states, spec_step_idx=spec_step_idx) + except TypeError: + logits = draft_model.get_logits(hidden_states) + else: + logits = self.target_model.get_logits(hidden_states) + return logits + def get_outputs(self, model_outputs: dict[str, torch.Tensor], model_inputs: ModelInputs, @@ -21,6 +92,9 @@ def get_outputs(self, """Get outputs.""" hidden_states = model_outputs['hidden_states'] model_metas = model_outputs['model_metas'] + if model_metas is None: + model_metas = model_inputs.model_metas + spec_step_idx = self.get_spec_step_idx(model_inputs) if extra_inputs is not None: last_token_loc = extra_inputs.last_token_indices target_hidden_states = model_inputs.target_hidden_states[:, last_token_loc] @@ -28,6 +102,6 @@ def get_outputs(self, else: target_hidden_states = hidden_states - logits = self.get_logits(hidden_states)[0] + logits = self.get_logits(hidden_states, spec_step_idx=spec_step_idx)[0] draft_token_ids = logits.argmax(dim=-1, keepdim=True) return draft_token_ids, model_metas, target_hidden_states diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 8a0e2ba64a..eec1be9b50 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -3,8 +3,6 @@ import torch from torch.profiler import record_function -from lmdeploy.utils import get_logger - from ..backends import get_backend from ..config import BackendConfig, CacheConfig, MiscConfig, ModelConfig, SpecDecodeConfig from ..engine.cache_engine import CacheEngine @@ -17,8 +15,6 @@ from .proposers.base import build_specdecode_proposer from .reject_sampler import RejectionSampler -logger = get_logger('lmdeploy') - def _expand_sampling_inputs(sampling_inputs: SamplingInputs, num_tokens: int) -> SamplingInputs: """Expand per-batch SamplingInputs to per-token by repeating each batch @@ -404,7 +400,11 @@ async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ARSpecEx if loop_idx < loop_count - 1: step_seqlens = inputs.seq_length.new_ones(inputs.seq_length.size(0)) inputs = inputs.step(draft_token_ids.transpose(0, 1), step_seqlens) - inputs.model_metas = model_metas + inputs.model_metas = self.proposer.with_spec_step_idx( + model_metas, + inputs.seq_length.size(0), + loop_idx + 1, + ) inputs.target_hidden_states = target_hidden_states if inputs.target_position_ids is not None: inputs.target_position_ids += 1 diff --git a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py index aa8661726e..9178bc2ea8 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py +++ b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py @@ -93,6 +93,24 @@ def _reindex_model_inputs_arspec( ) +def _with_num_accepted_tokens(model_metas: Any, num_accepted_tokens: torch.Tensor): + """Attach accepted-token counts to per-sequence model metas.""" + batch_size = num_accepted_tokens.size(0) + if model_metas is None: + model_metas = [None] * batch_size + + updated = [] + for batch_idx in range(batch_size): + model_meta = model_metas[batch_idx] if batch_idx < len(model_metas) else None + if model_meta is None: + model_meta = {} + else: + model_meta = dict(model_meta) + model_meta['num_accepted_tokens'] = int(num_accepted_tokens[batch_idx].item()) + updated.append(model_meta) + return updated + + @dataclass class ARSpecStepInputs(StepInputs): """AR Spec paradigm step inputs.""" @@ -118,9 +136,11 @@ def merge_prefill( [next_token_ids_expanded, extra_outputs.draft_token_ids], dim=-1) max_q_seqlen = next_token_ids_expanded.size(-1) next_token_ids_flat = next_token_ids_expanded.flatten()[None, :] + num_accepted_tokens = next_token_ids.new_ones(next_token_ids.size(0)) inputs = get_model_inputs_next_decoding( inputs, next_token_ids_flat, - max_q_seqlen=max_q_seqlen, model_metas=model_metas) + max_q_seqlen=max_q_seqlen, + model_metas=_with_num_accepted_tokens(model_metas, num_accepted_tokens)) # update mrope pos ids mrope_pos_ids = inputs.mrope_pos_ids @@ -175,13 +195,13 @@ def step_decode( # advance model state model_inputs.is_decoding = True - model_inputs.model_metas = model_metas + step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens + model_inputs.model_metas = _with_num_accepted_tokens(model_metas, step_seqlens) # update extra inputs extra_inputs.output_token_ids = extra_outputs.draft_token_ids # update inputs with rejected token adjustment - step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens batch_size = step_seqlens.size(0) input_ids = next_token_ids.new_empty((batch_size, num_spec_tokens + 1)) input_ids[:, 0] = next_token_ids From b853f1ce05d9a6c65b1595eab8314f2c728c9789 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Thu, 14 May 2026 05:10:03 +0000 Subject: [PATCH 3/7] [ascend] fix attn mask --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 0dd3634e81..ffd0a3b51b 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -273,14 +273,7 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_ device=step_context.block_offsets.device), diagonal=max_kv_seq_len - max_q_seq_len + 1)) else: - mask_width = 2048 - causal_width = min(max_kv_seq_len, mask_width) - attention_mask.append( - torch.triu(torch.ones(mask_width, - mask_width, - dtype=torch.bool, - device=step_context.block_offsets.device), - diagonal=causal_width - max_q_seq_len + 1)) + attention_mask.append(torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=step_context.block_offsets.device), diagonal=1)) kv_start_indices = torch.cat(kv_start_indices) From 698dd00737294e24bb3ad3e45c90a527af7b4113 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Mon, 18 May 2026 08:04:15 +0000 Subject: [PATCH 4/7] Refactor GDN and conv1d computation flow --- .../backends/dlinfer/ascend/op_backend.py | 107 ++++++++++++------ .../pytorch/backends/dlinfer/attention.py | 7 +- .../pytorch/kernels/dlinfer/pagedattention.py | 14 +-- 3 files changed, 77 insertions(+), 51 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index ffd0a3b51b..1cf8410018 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -13,6 +13,7 @@ from lmdeploy.pytorch import envs as _envs from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.pytorch.distributed import get_dist_manager +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from lmdeploy.utils import get_logger from ..moe import DlinferMoECommType, DlinferMoeMetadata @@ -156,19 +157,17 @@ def update_step_context(cls, step_context): block_num, block_size, *_ = step_context.kv_caches[0][0].shape is_prefill_no_cache = False - is_multi_token_decoding = False - actual_seq_lengths_q = None - num_accepted_tokens = None - # Pre-compute CPU q_seqlens for decode path (reused in get_cpu_seqlens) - q_seqlens_cpu_early = None + if not step_context.is_decoding: is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) + is_multi_token_decoding = False + is_decoding = False else: - q_seqlens_cpu_early = step_context.q_seqlens.cpu() - is_multi_token_decoding = q_seqlens_cpu_early.max().item() > 1 - effective_is_decoding = step_context.is_decoding and not is_multi_token_decoding - if is_multi_token_decoding: - actual_seq_lengths_q = q_seqlens_cpu_early.cumsum(0).to(torch.int32) + # Device-side scalar op; avoids a full D2H copy on every regular decode step + is_multi_token_decoding = step_context.q_seqlens.max().item() > 1 + # is_decoding: True only for regular single-token decode (original semantics) + is_decoding = not is_multi_token_decoding + if step_context.is_decoding and step_context.model_metas is not None: accepted = [] for model_meta in step_context.model_metas: @@ -196,8 +195,7 @@ def get_total_slots(): cls.total_slots = cls.total_slots.view(block_num, block_size) return cls.total_slots - def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding, - _q_seqlens_cpu_early=None): + def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding): """Get sequence lengths on CPU. Returns: @@ -205,11 +203,11 @@ def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding, kv_seqlens_cpu: kv sequence lengths (per sequence), used for list/max seqlens calculation. """ - if is_decoding and not is_multi_token_decoding: + if is_decoding: q_seqlens_cpu = None kv_seqlens_cpu = step_context.kv_seqlens.cpu() elif is_multi_token_decoding: - q_seqlens_cpu = _q_seqlens_cpu_early + q_seqlens_cpu = step_context.q_seqlens.cpu() kv_seqlens_cpu = step_context.kv_seqlens.cpu() elif is_prefill_no_cache: q_seqlens_cpu = step_context.q_seqlens.cpu() @@ -244,7 +242,8 @@ def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None): return torch.arange(1, batch_size + 1, dtype=torch.int32) elif is_prefill_no_cache: return q_seqlens_cpu - return q_seqlens_cpu.cumsum(dim=0) + # for paged_prefill, eg. MTP, prefix caching + return q_seqlens_cpu.cumsum(dim=0).to(torch.int32) def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len): @@ -300,7 +299,7 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group): if ep_size <= 1: return 0, 0, 0 # get padded_tokens_current_rank - is_graph = cls.enable_graph and effective_is_decoding + is_graph = cls.enable_graph and is_decoding if is_graph: from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size actual_tokens_current_rank = step_context.q_seqlens.shape[0] @@ -334,7 +333,7 @@ def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size): if ep_size <= 1: return DlinferMoECommType.ALLGATHER mc2_token_capacity = init_mc2_token_capacity(tp_size) - is_graph = cls.enable_graph and effective_is_decoding + is_graph = cls.enable_graph and is_decoding if is_graph: max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size if SocVersion.is_A2(): @@ -376,18 +375,18 @@ def get_moe_group_name(group): group_name = backend.get_hccl_comm_name(local_rank) return group_name - q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(effective_is_decoding, is_prefill_no_cache, - is_multi_token_decoding, q_seqlens_cpu_early) - q_seqlens_list, kv_seqlens_list = get_list_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_cpu, + q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(is_decoding, is_prefill_no_cache, + is_multi_token_decoding) + q_seqlens_list, kv_seqlens_list = get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu, kv_seqlens_cpu) - max_q_seq_len, max_kv_seq_len = get_max_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_list, + max_q_seq_len, max_kv_seq_len = get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list) - kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(effective_is_decoding, + kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len) - q_seqlens_cpu = update_q_seqlens(effective_is_decoding, is_prefill_no_cache, q_seqlens_cpu) - + q_seqlens_cpu = update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu) + if not cls.enable_graph and step_context.kv_quant_policy == 8: record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE' @@ -403,18 +402,58 @@ def get_moe_group_name(group): cu_seqlens = None has_initial_state = None - + spec_conv_offsets = None + spec_state_offsets = None + cache_seqlens = None is_gated_delta = step_context.model_config.is_gated_delta if is_gated_delta: - q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype, - device=step_context.q_seqlens.device) - cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() - if not effective_is_decoding: - has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens) + q_seqlens = step_context.q_seqlens + kv_seqlens = step_context.kv_seqlens + + q_start_loc = step_context.q_start_loc.to(dtype=q_seqlens.dtype, + device=q_seqlens.device) + cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int() + cache_seqlens = (kv_seqlens - q_seqlens).contiguous() + + + states_shapes = step_context.model_config.states_shapes + if not is_decoding and not is_multi_token_decoding and len(states_shapes) > 0: + has_initial_state = ~(q_seqlens == kv_seqlens) + num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens + # # Conv ring buffer: conv_state_len = conv_kernel_size + num_spec_tokens. + conv_state_len = states_shapes[0][0][0] + conv_kernel_size = conv_state_len - num_spec_tokens + + if num_spec_tokens > 0: + state_slots = 1 + num_spec_tokens + spec_state_offsets = ( + torch.remainder(cache_seqlens, state_slots), + torch.remainder(kv_seqlens, state_slots), + ) + + range_idx = torch.arange( + -conv_kernel_size, + 0, + device=cache_seqlens.device, + dtype=torch.int32, + ) + # Read the (conv_kernel_size - 1) tokens preceding the current write + # window from the circular buffer. + read_conv_offsets = torch.remainder( + cache_seqlens[:, None] + range_idx[1:][None], + conv_state_len, + ).to(torch.int64) + # Write the last conv_kernel_size tokens of this prefill batch into + # circular-buffer slots so the next decode read aligns naturally. + write_conv_offsets = torch.remainder( + kv_seqlens[:, None] + range_idx[None], + conv_state_len, + ).to(torch.int64) + spec_conv_offsets = (read_conv_offsets, write_conv_offsets) attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( - effective_is_decoding, + is_decoding, step_context.block_offsets, # cu_seqlens is only used in GDN and is passed down via q_start_loc. # Otherwise, q_start_loc is None. @@ -431,9 +470,9 @@ def get_moe_group_name(group): quant_meta=AscendKVQuantMeta.quant_meta, has_initial_state=has_initial_state, is_multi_token_decoding=is_multi_token_decoding, - actual_seq_lengths_q=actual_seq_lengths_q, - num_accepted_tokens=num_accepted_tokens, - kv_seqlens_device=step_context.kv_seqlens, + spec_conv_offsets=spec_conv_offsets, + spec_state_offsets=spec_state_offsets, + cache_seqlens=cache_seqlens, ) step_context.attn_metadata = attn_metadata diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 702f1457fb..3cf0fcc1a1 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -20,9 +20,9 @@ class DlinferAttentionMetadata(AttentionMetadata): cu_seq_lens_kv: Tensor | None = None has_initial_state: Tensor | None = None is_multi_token_decoding: bool = False - actual_seq_lengths_q: Tensor | None = None - num_accepted_tokens: Tensor | None = None - kv_seqlens_device: Tensor | None = None + spec_conv_offsets: Sequence[Tensor] = tuple() + spec_state_offsets: Sequence[Tensor] = tuple() + cache_seqlens: Tensor | None = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): @@ -148,7 +148,6 @@ def forward( kv_zeros=kv_zeros, quant_bits=quant_bits, is_multi_token_decoding=attn_metadata.is_multi_token_decoding, - actual_seq_lengths_q=attn_metadata.actual_seq_lengths_q, ) return attn_output diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 5814e91970..5819602b96 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -138,20 +138,8 @@ def paged_attention_fwd( kv_zeros: Tensor | None = None, quant_bits: int | None = 0, is_multi_token_decoding: bool = False, - actual_seq_lengths_q: Tensor | None = None, ): if is_multi_token_decoding: - if actual_seq_lengths_q is None: - raise ValueError('MTP multi-token decode requires actual_seq_lengths_q for TND attention.') - if actual_seq_lengths_q.dim() != 1 or kv_seqlens.dim() != 1: - raise ValueError('TND attention expects 1D q/kv length tensors.') - if block_offsets.size(0) != actual_seq_lengths_q.numel(): - raise ValueError('TND attention expects per-sequence block tables.') - if kv_seqlens.numel() != actual_seq_lengths_q.numel(): - raise ValueError('TND attention expects kv lengths per sequence.') - # MTP verify is semantically a "mini-prefill": multiple tokens per - # sequence with TND layout, sparse_mode=3 and causal mask. - # Reuse paged_prefill_attention, passing cumulative q lengths. return prefill_attention( query_states, key_states, @@ -161,7 +149,7 @@ def paged_attention_fwd( value_cache, block_offsets, q_start_loc, - actual_seq_lengths_q, + q_seqlens, kv_seqlens, cu_seq_lens_kv, max_q_seq_len, From 2a90115eda38a26895dbf9bbfba13e93b236e5a4 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Mon, 18 May 2026 08:06:42 +0000 Subject: [PATCH 5/7] Refactor GDN and conv1d computation flow --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 1cf8410018..0d664869b1 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -195,7 +195,7 @@ def get_total_slots(): cls.total_slots = cls.total_slots.view(block_num, block_size) return cls.total_slots - def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding): + def get_cpu_seqlens(is_decoding, is_prefill_no_cache): """Get sequence lengths on CPU. Returns: @@ -206,9 +206,6 @@ def get_cpu_seqlens(is_decoding, is_prefill_no_cache, is_multi_token_decoding): if is_decoding: q_seqlens_cpu = None kv_seqlens_cpu = step_context.kv_seqlens.cpu() - elif is_multi_token_decoding: - q_seqlens_cpu = step_context.q_seqlens.cpu() - kv_seqlens_cpu = step_context.kv_seqlens.cpu() elif is_prefill_no_cache: q_seqlens_cpu = step_context.q_seqlens.cpu() kv_seqlens_cpu = q_seqlens_cpu @@ -375,8 +372,7 @@ def get_moe_group_name(group): group_name = backend.get_hccl_comm_name(local_rank) return group_name - q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(is_decoding, is_prefill_no_cache, - is_multi_token_decoding) + q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(is_decoding, is_prefill_no_cache) q_seqlens_list, kv_seqlens_list = get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu, kv_seqlens_cpu) max_q_seq_len, max_kv_seq_len = get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list, From 16605d127bdea32fda8d160b3706b87f21fac2c4 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Thu, 21 May 2026 09:17:46 +0000 Subject: [PATCH 6/7] [ascend] remove unused code --- .../backends/dlinfer/ascend/op_backend.py | 13 ---- .../pytorch/backends/dlinfer/attention.py | 1 - .../pytorch/kernels/dlinfer/pagedattention.py | 29 +------ lmdeploy/pytorch/models/deepseek_mtp.py | 6 -- lmdeploy/pytorch/models/qwen3_5_mtp.py | 9 +-- .../spec_decode/proposers/deepseek_mtp.py | 76 +------------------ lmdeploy/pytorch/spec_decode/spec_agent.py | 10 +-- .../pytorch/strategies/ar_spec/step_inputs.py | 27 +------ 8 files changed, 11 insertions(+), 160 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 0d664869b1..24aece7415 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -163,23 +163,10 @@ def update_step_context(cls, step_context): is_multi_token_decoding = False is_decoding = False else: - # Device-side scalar op; avoids a full D2H copy on every regular decode step is_multi_token_decoding = step_context.q_seqlens.max().item() > 1 # is_decoding: True only for regular single-token decode (original semantics) is_decoding = not is_multi_token_decoding - if step_context.is_decoding and step_context.model_metas is not None: - accepted = [] - for model_meta in step_context.model_metas: - if isinstance(model_meta, dict): - accepted.append(int(model_meta.get('num_accepted_tokens', 1))) - else: - accepted.append(1) - num_accepted_tokens = torch.tensor( - accepted, - dtype=torch.int32, - device=step_context.block_offsets.device, - ) if step_context.block_offsets.dtype != torch.int32: step_context.block_offsets = step_context.block_offsets.to(torch.int32) if step_context.kv_seqlens.dtype != torch.int32: diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 3cf0fcc1a1..17e6773c51 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -147,7 +147,6 @@ def forward( kv_scales=kv_scales, kv_zeros=kv_zeros, quant_bits=quant_bits, - is_multi_token_decoding=attn_metadata.is_multi_token_decoding, ) return attn_output diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py index 5819602b96..5e75de0a5b 100644 --- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py +++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py @@ -137,35 +137,8 @@ def paged_attention_fwd( kv_scales: Tensor | None = None, kv_zeros: Tensor | None = None, quant_bits: int | None = 0, - is_multi_token_decoding: bool = False, ): - if is_multi_token_decoding: - return prefill_attention( - query_states, - key_states, - value_states, - attn_output, - key_cache, - value_cache, - block_offsets, - q_start_loc, - q_seqlens, - kv_seqlens, - cu_seq_lens_kv, - max_q_seq_len, - max_kv_seq_len, - block_size, - num_heads, - num_kv_heads, - v_head_size, - attn_mask, - softmax_scale, - is_prefill_no_cache=False, - kv_scales=kv_scales, - kv_zeros=kv_zeros, - quant_bits=quant_bits, - ) - elif not is_decoding: + if not is_decoding: return prefill_attention( query_states, key_states, diff --git a/lmdeploy/pytorch/models/deepseek_mtp.py b/lmdeploy/pytorch/models/deepseek_mtp.py index 073efce304..a36a14cd34 100644 --- a/lmdeploy/pytorch/models/deepseek_mtp.py +++ b/lmdeploy/pytorch/models/deepseek_mtp.py @@ -601,11 +601,6 @@ def prepare_inputs_for_generation( position_ids = context.position_ids attn_metadata = context.attn_metadata target_hidden_states = context.target_hidden_states - spec_step_idx = 0 - if context.model_metas: - model_meta = context.model_metas[0] - if isinstance(model_meta, dict): - spec_step_idx = int(model_meta.get('spec_step_idx', 0)) return dict( input_ids=input_ids, position_ids=position_ids, @@ -613,7 +608,6 @@ def prepare_inputs_for_generation( attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, target_hidden_states=target_hidden_states, - spec_step_idx=spec_step_idx, ) def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter], diff --git a/lmdeploy/pytorch/models/qwen3_5_mtp.py b/lmdeploy/pytorch/models/qwen3_5_mtp.py index 312e168425..78cdd7172e 100644 --- a/lmdeploy/pytorch/models/qwen3_5_mtp.py +++ b/lmdeploy/pytorch/models/qwen3_5_mtp.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. + from collections.abc import Iterable from typing import Any @@ -226,7 +227,6 @@ def forward( attn_metadata: Any, inputs_embeds: torch.Tensor | None = None, mrope_position_ids: torch.Tensor | None = None, - spec_step_idx: int = 0, **kwargs, ): """Model forward, return logits.""" @@ -244,7 +244,6 @@ def forward( attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, mrope_position_ids=mrope_position_ids, - spec_step_idx=spec_step_idx, previous_hidden_states=target_hidden_states, all_routed_experts=all_routed_experts, ) @@ -295,11 +294,6 @@ def prepare_inputs_for_generation( attn_metadata = context.attn_metadata target_hidden_states = context.target_hidden_states mrope_position_ids = getattr(context, 'mrope_position_ids', None) - spec_step_idx = 0 - if context.model_metas: - model_meta = context.model_metas[0] - if isinstance(model_meta, dict): - spec_step_idx = int(model_meta.get('spec_step_idx', 0)) if context.target_inputs_embeds is not None: inputs_embeds = context.target_inputs_embeds @@ -311,7 +305,6 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, target_hidden_states=target_hidden_states, mrope_position_ids=mrope_position_ids, - spec_step_idx=spec_step_idx, ) def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter], diff --git a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py index 92b29e0477..c0d65ca33b 100644 --- a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py +++ b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py @@ -14,77 +14,6 @@ @SPEC_PROPOSERS.register_module(name='deepseek_mtp') class DeepseekMTP(BaseSpecProposer): - @staticmethod - def get_spec_step_idx(model_inputs: ModelInputs) -> int: - """Read the current draft step from model_metas.""" - model_metas = model_inputs.model_metas - if not model_metas: - return 0 - model_meta = model_metas[0] - if not isinstance(model_meta, dict): - return 0 - return int(model_meta.get('spec_step_idx', 0)) - - @staticmethod - def with_spec_step_idx( - model_metas: list[dict] | None, - batch_size: int, - spec_step_idx: int, - ): - """Attach spec_step_idx to every batch meta entry.""" - if model_metas is None: - model_metas = [None] * batch_size - - updated = [] - for batch_idx in range(batch_size): - model_meta = model_metas[batch_idx] if batch_idx < len(model_metas) else None - if model_meta is None: - model_meta = {} - else: - model_meta = dict(model_meta) - model_meta['spec_step_idx'] = spec_step_idx - updated.append(model_meta) - return updated - - def update_inputs_decoding( - self, - model_inputs: ModelInputs, - extra_inputs: ARSpecExtraInputs, - next_input_ids: torch.Tensor, - target_hidden_states: torch.Tensor, - model_metas: list[dict], - ): - """Update decoding inputs with deepseek-style spec step metadata.""" - new_inputs = super().update_inputs_decoding( - model_inputs, - extra_inputs, - next_input_ids, - target_hidden_states, - model_metas, - ) - return new_inputs.clone( - model_metas=self.with_spec_step_idx( - model_metas, - new_inputs.seq_length.size(0), - 0, - ) - ) - - def get_logits(self, hidden_states: torch.Tensor, spec_step_idx: int = 0): - """Get logits of deepseek/qwen-style MTP draft models.""" - draft_model = self.model - if not isinstance(draft_model, torch.nn.Module): - draft_model = draft_model.model - - if hasattr(draft_model, 'get_logits'): - try: - logits = draft_model.get_logits(hidden_states, spec_step_idx=spec_step_idx) - except TypeError: - logits = draft_model.get_logits(hidden_states) - else: - logits = self.target_model.get_logits(hidden_states) - return logits - def get_outputs(self, model_outputs: dict[str, torch.Tensor], model_inputs: ModelInputs, @@ -92,9 +21,6 @@ def get_outputs(self, """Get outputs.""" hidden_states = model_outputs['hidden_states'] model_metas = model_outputs['model_metas'] - if model_metas is None: - model_metas = model_inputs.model_metas - spec_step_idx = self.get_spec_step_idx(model_inputs) if extra_inputs is not None: last_token_loc = extra_inputs.last_token_indices target_hidden_states = model_inputs.target_hidden_states[:, last_token_loc] @@ -102,6 +28,6 @@ def get_outputs(self, else: target_hidden_states = hidden_states - logits = self.get_logits(hidden_states, spec_step_idx=spec_step_idx)[0] + logits = self.get_logits(hidden_states)[0] draft_token_ids = logits.argmax(dim=-1, keepdim=True) return draft_token_ids, model_metas, target_hidden_states diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index eec1be9b50..8a0e2ba64a 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -3,6 +3,8 @@ import torch from torch.profiler import record_function +from lmdeploy.utils import get_logger + from ..backends import get_backend from ..config import BackendConfig, CacheConfig, MiscConfig, ModelConfig, SpecDecodeConfig from ..engine.cache_engine import CacheEngine @@ -15,6 +17,8 @@ from .proposers.base import build_specdecode_proposer from .reject_sampler import RejectionSampler +logger = get_logger('lmdeploy') + def _expand_sampling_inputs(sampling_inputs: SamplingInputs, num_tokens: int) -> SamplingInputs: """Expand per-batch SamplingInputs to per-token by repeating each batch @@ -400,11 +404,7 @@ async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ARSpecEx if loop_idx < loop_count - 1: step_seqlens = inputs.seq_length.new_ones(inputs.seq_length.size(0)) inputs = inputs.step(draft_token_ids.transpose(0, 1), step_seqlens) - inputs.model_metas = self.proposer.with_spec_step_idx( - model_metas, - inputs.seq_length.size(0), - loop_idx + 1, - ) + inputs.model_metas = model_metas inputs.target_hidden_states = target_hidden_states if inputs.target_position_ids is not None: inputs.target_position_ids += 1 diff --git a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py index 9178bc2ea8..044a587f3a 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py +++ b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py @@ -93,24 +93,6 @@ def _reindex_model_inputs_arspec( ) -def _with_num_accepted_tokens(model_metas: Any, num_accepted_tokens: torch.Tensor): - """Attach accepted-token counts to per-sequence model metas.""" - batch_size = num_accepted_tokens.size(0) - if model_metas is None: - model_metas = [None] * batch_size - - updated = [] - for batch_idx in range(batch_size): - model_meta = model_metas[batch_idx] if batch_idx < len(model_metas) else None - if model_meta is None: - model_meta = {} - else: - model_meta = dict(model_meta) - model_meta['num_accepted_tokens'] = int(num_accepted_tokens[batch_idx].item()) - updated.append(model_meta) - return updated - - @dataclass class ARSpecStepInputs(StepInputs): """AR Spec paradigm step inputs.""" @@ -136,11 +118,8 @@ def merge_prefill( [next_token_ids_expanded, extra_outputs.draft_token_ids], dim=-1) max_q_seqlen = next_token_ids_expanded.size(-1) next_token_ids_flat = next_token_ids_expanded.flatten()[None, :] - num_accepted_tokens = next_token_ids.new_ones(next_token_ids.size(0)) inputs = get_model_inputs_next_decoding( - inputs, next_token_ids_flat, - max_q_seqlen=max_q_seqlen, - model_metas=_with_num_accepted_tokens(model_metas, num_accepted_tokens)) + inputs, next_token_ids_flat, max_q_seqlen=max_q_seqlen, model_metas=model_metas) # update mrope pos ids mrope_pos_ids = inputs.mrope_pos_ids @@ -195,13 +174,13 @@ def step_decode( # advance model state model_inputs.is_decoding = True - step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens - model_inputs.model_metas = _with_num_accepted_tokens(model_metas, step_seqlens) + model_inputs.model_metas = model_metas # update extra inputs extra_inputs.output_token_ids = extra_outputs.draft_token_ids # update inputs with rejected token adjustment + step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens batch_size = step_seqlens.size(0) input_ids = next_token_ids.new_empty((batch_size, num_spec_tokens + 1)) input_ids[:, 0] = next_token_ids From 58839571838f694337d3399d9891f72be652cbfa Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Thu, 21 May 2026 09:52:47 +0000 Subject: [PATCH 7/7] [ascend] remote unused code --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 4 +++- lmdeploy/pytorch/strategies/ar_spec/step_inputs.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 24aece7415..41ea0b536d 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -256,7 +256,9 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_ device=step_context.block_offsets.device), diagonal=max_kv_seq_len - max_q_seq_len + 1)) else: - attention_mask.append(torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=step_context.block_offsets.device), diagonal=1)) + attention_mask.append( + torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=step_context.block_offsets.device), + diagonal=1)) kv_start_indices = torch.cat(kv_start_indices) diff --git a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py index 044a587f3a..aa8661726e 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py +++ b/lmdeploy/pytorch/strategies/ar_spec/step_inputs.py @@ -119,7 +119,8 @@ def merge_prefill( max_q_seqlen = next_token_ids_expanded.size(-1) next_token_ids_flat = next_token_ids_expanded.flatten()[None, :] inputs = get_model_inputs_next_decoding( - inputs, next_token_ids_flat, max_q_seqlen=max_q_seqlen, model_metas=model_metas) + inputs, next_token_ids_flat, + max_q_seqlen=max_q_seqlen, model_metas=model_metas) # update mrope pos ids mrope_pos_ids = inputs.mrope_pos_ids