diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 8cc41f5df8..41ea0b536d 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -13,6 +13,7 @@ from lmdeploy.pytorch import envs as _envs from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.pytorch.distributed import get_dist_manager +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from lmdeploy.utils import get_logger from ..moe import DlinferMoECommType, DlinferMoeMetadata @@ -156,8 +157,16 @@ def update_step_context(cls, step_context): block_num, block_size, *_ = step_context.kv_caches[0][0].shape is_prefill_no_cache = False + if not step_context.is_decoding: is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) + is_multi_token_decoding = False + is_decoding = False + else: + is_multi_token_decoding = step_context.q_seqlens.max().item() > 1 + # is_decoding: True only for regular single-token decode (original semantics) + is_decoding = not is_multi_token_decoding + if step_context.block_offsets.dtype != torch.int32: step_context.block_offsets = step_context.block_offsets.to(torch.int32) if step_context.kv_seqlens.dtype != torch.int32: @@ -180,8 +189,6 @@ def get_cpu_seqlens(is_decoding, is_prefill_no_cache): q_seqlens_cpu: query sequence lengths (per sequence). kv_seqlens_cpu: kv sequence lengths (per sequence), used for list/max seqlens calculation. - kv_seqlens_expanded: kv sequence lengths expanded per token via - repeat_interleave, used for attention metadata. """ if is_decoding: q_seqlens_cpu = None @@ -219,7 +226,8 @@ def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None): return torch.arange(1, batch_size + 1, dtype=torch.int32) elif is_prefill_no_cache: return q_seqlens_cpu - return q_seqlens_cpu.cumsum(dim=0) + # for paged_prefill, eg. MTP, prefix caching + return q_seqlens_cpu.cumsum(dim=0).to(torch.int32) def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len): @@ -277,7 +285,7 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group): if ep_size <= 1: return 0, 0, 0 # get padded_tokens_current_rank - is_graph = cls.enable_graph and step_context.is_decoding + is_graph = cls.enable_graph and is_decoding if is_graph: from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size actual_tokens_current_rank = step_context.q_seqlens.shape[0] @@ -311,7 +319,7 @@ def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size): if ep_size <= 1: return DlinferMoECommType.ALLGATHER mc2_token_capacity = init_mc2_token_capacity(tp_size) - is_graph = cls.enable_graph and step_context.is_decoding + is_graph = cls.enable_graph and is_decoding if is_graph: max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size if SocVersion.is_A2(): @@ -353,17 +361,17 @@ def get_moe_group_name(group): group_name = backend.get_hccl_comm_name(local_rank) return group_name - q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache) - q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu, + q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(is_decoding, is_prefill_no_cache) + q_seqlens_list, kv_seqlens_list = get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu, kv_seqlens_cpu) - max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list, + max_q_seq_len, max_kv_seq_len = get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list) - kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding, + kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len) - q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu) - + q_seqlens_cpu = update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu) + if not cls.enable_graph and step_context.kv_quant_policy == 8: record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE' @@ -379,18 +387,58 @@ def get_moe_group_name(group): cu_seqlens = None has_initial_state = None - + spec_conv_offsets = None + spec_state_offsets = None + cache_seqlens = None is_gated_delta = step_context.model_config.is_gated_delta if is_gated_delta: - q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype, - device=step_context.q_seqlens.device) - cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() - if not step_context.is_decoding: - has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens) + q_seqlens = step_context.q_seqlens + kv_seqlens = step_context.kv_seqlens + + q_start_loc = step_context.q_start_loc.to(dtype=q_seqlens.dtype, + device=q_seqlens.device) + cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int() + cache_seqlens = (kv_seqlens - q_seqlens).contiguous() + + + states_shapes = step_context.model_config.states_shapes + if not is_decoding and not is_multi_token_decoding and len(states_shapes) > 0: + has_initial_state = ~(q_seqlens == kv_seqlens) + num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens + # # Conv ring buffer: conv_state_len = conv_kernel_size + num_spec_tokens. + conv_state_len = states_shapes[0][0][0] + conv_kernel_size = conv_state_len - num_spec_tokens + + if num_spec_tokens > 0: + state_slots = 1 + num_spec_tokens + spec_state_offsets = ( + torch.remainder(cache_seqlens, state_slots), + torch.remainder(kv_seqlens, state_slots), + ) + + range_idx = torch.arange( + -conv_kernel_size, + 0, + device=cache_seqlens.device, + dtype=torch.int32, + ) + # Read the (conv_kernel_size - 1) tokens preceding the current write + # window from the circular buffer. + read_conv_offsets = torch.remainder( + cache_seqlens[:, None] + range_idx[1:][None], + conv_state_len, + ).to(torch.int64) + # Write the last conv_kernel_size tokens of this prefill batch into + # circular-buffer slots so the next decode read aligns naturally. + write_conv_offsets = torch.remainder( + kv_seqlens[:, None] + range_idx[None], + conv_state_len, + ).to(torch.int64) + spec_conv_offsets = (read_conv_offsets, write_conv_offsets) attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( - step_context.is_decoding, + is_decoding, step_context.block_offsets, # cu_seqlens is only used in GDN and is passed down via q_start_loc. # Otherwise, q_start_loc is None. @@ -406,6 +454,10 @@ def get_moe_group_name(group): quant_policy=step_context.kv_quant_policy, quant_meta=AscendKVQuantMeta.quant_meta, has_initial_state=has_initial_state, + is_multi_token_decoding=is_multi_token_decoding, + spec_conv_offsets=spec_conv_offsets, + spec_state_offsets=spec_state_offsets, + cache_seqlens=cache_seqlens, ) step_context.attn_metadata = attn_metadata @@ -462,6 +514,12 @@ def init(): logger.warning(f'Error during Ascend initialization: {str(e)}. ' 'Please check your Ascend environment configuration.') + try: + import dlinfer.framework.lmdeploy_ext.device # noqa: F401 — triggers vendor_device_init() + except ImportError: + logger.warning('dlinfer framework extensions not found. ' + 'Ascend-specific model patches will not be applied.') + try: from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton init_device_properties_triton() diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index a8eea27545..17e6773c51 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -19,6 +19,10 @@ class DlinferAttentionMetadata(AttentionMetadata): quant_meta: dict = None cu_seq_lens_kv: Tensor | None = None has_initial_state: Tensor | None = None + is_multi_token_decoding: bool = False + spec_conv_offsets: Sequence[Tensor] = tuple() + spec_state_offsets: Sequence[Tensor] = tuple() + cache_seqlens: Tensor | None = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 9b26222d31..6c5c286a84 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -566,6 +566,7 @@ def from_config( target_cache_cfg: CacheConfig, target_model: str = None, dtype: str = 'auto', + device_type: str = 'auto', ): model = model or target_model model_config = ModelConfig.from_pretrained(model, @@ -574,6 +575,7 @@ def from_config( is_draft_model=True, spec_method=method, block_size=target_cache_cfg.block_size, + device_type=device_type, ) cache_config = None # include medusa diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index 3c5a005daf..96e886eadf 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -114,5 +114,6 @@ def build_specdecode_config(target_model, speculative_config: SpeculativeConfig, target_model=target_model, target_cache_cfg=cache_config, dtype=engine_config.dtype, + device_type=engine_config.device_type, ) return specdecode_config