Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 76 additions & 18 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from ..moe import DlinferMoECommType, DlinferMoeMetadata
Expand Down Expand Up @@ -156,8 +157,16 @@ def update_step_context(cls, step_context):

block_num, block_size, *_ = step_context.kv_caches[0][0].shape
is_prefill_no_cache = False

if not step_context.is_decoding:
is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
is_multi_token_decoding = False
is_decoding = False
else:
is_multi_token_decoding = step_context.q_seqlens.max().item() > 1
# is_decoding: True only for regular single-token decode (original semantics)
is_decoding = not is_multi_token_decoding

if step_context.block_offsets.dtype != torch.int32:
step_context.block_offsets = step_context.block_offsets.to(torch.int32)
if step_context.kv_seqlens.dtype != torch.int32:
Expand All @@ -180,8 +189,6 @@ def get_cpu_seqlens(is_decoding, is_prefill_no_cache):
q_seqlens_cpu: query sequence lengths (per sequence).
kv_seqlens_cpu: kv sequence lengths (per sequence), used for
list/max seqlens calculation.
kv_seqlens_expanded: kv sequence lengths expanded per token via
repeat_interleave, used for attention metadata.
"""
if is_decoding:
q_seqlens_cpu = None
Expand Down Expand Up @@ -219,7 +226,8 @@ def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None):
return torch.arange(1, batch_size + 1, dtype=torch.int32)
elif is_prefill_no_cache:
return q_seqlens_cpu
return q_seqlens_cpu.cumsum(dim=0)
# for paged_prefill, eg. MTP, prefix caching
return q_seqlens_cpu.cumsum(dim=0).to(torch.int32)

def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list,
max_q_seq_len, max_kv_seq_len):
Expand Down Expand Up @@ -277,7 +285,7 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
if ep_size <= 1:
return 0, 0, 0
# get padded_tokens_current_rank
is_graph = cls.enable_graph and step_context.is_decoding
is_graph = cls.enable_graph and is_decoding
if is_graph:
from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
actual_tokens_current_rank = step_context.q_seqlens.shape[0]
Expand Down Expand Up @@ -311,7 +319,7 @@ def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
if ep_size <= 1:
return DlinferMoECommType.ALLGATHER
mc2_token_capacity = init_mc2_token_capacity(tp_size)
is_graph = cls.enable_graph and step_context.is_decoding
is_graph = cls.enable_graph and is_decoding
if is_graph:
max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size
if SocVersion.is_A2():
Expand Down Expand Up @@ -353,17 +361,17 @@ def get_moe_group_name(group):
group_name = backend.get_hccl_comm_name(local_rank)
return group_name

q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu,
q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(is_decoding, is_prefill_no_cache)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu,
kv_seqlens_cpu)
max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list,
max_q_seq_len, max_kv_seq_len = get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list,
kv_seqlens_list)
kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(is_decoding,
is_prefill_no_cache, q_seqlens_list,
kv_seqlens_list, max_q_seq_len,
max_kv_seq_len)
q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu)

q_seqlens_cpu = update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu)
if not cls.enable_graph and step_context.kv_quant_policy == 8:
record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'
Expand All @@ -379,18 +387,58 @@ def get_moe_group_name(group):

cu_seqlens = None
has_initial_state = None

spec_conv_offsets = None
spec_state_offsets = None
cache_seqlens = None
is_gated_delta = step_context.model_config.is_gated_delta
if is_gated_delta:
q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype,
device=step_context.q_seqlens.device)
cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()
if not step_context.is_decoding:
has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens)
q_seqlens = step_context.q_seqlens
kv_seqlens = step_context.kv_seqlens

q_start_loc = step_context.q_start_loc.to(dtype=q_seqlens.dtype,
device=q_seqlens.device)
cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
cache_seqlens = (kv_seqlens - q_seqlens).contiguous()


states_shapes = step_context.model_config.states_shapes
if not is_decoding and not is_multi_token_decoding and len(states_shapes) > 0:
has_initial_state = ~(q_seqlens == kv_seqlens)
num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens
# # Conv ring buffer: conv_state_len = conv_kernel_size + num_spec_tokens.
conv_state_len = states_shapes[0][0][0]
conv_kernel_size = conv_state_len - num_spec_tokens

if num_spec_tokens > 0:
state_slots = 1 + num_spec_tokens
spec_state_offsets = (
torch.remainder(cache_seqlens, state_slots),
torch.remainder(kv_seqlens, state_slots),
)

range_idx = torch.arange(
-conv_kernel_size,
0,
device=cache_seqlens.device,
dtype=torch.int32,
)
# Read the (conv_kernel_size - 1) tokens preceding the current write
# window from the circular buffer.
read_conv_offsets = torch.remainder(
cache_seqlens[:, None] + range_idx[1:][None],
conv_state_len,
).to(torch.int64)
# Write the last conv_kernel_size tokens of this prefill batch into
# circular-buffer slots so the next decode read aligns naturally.
write_conv_offsets = torch.remainder(
kv_seqlens[:, None] + range_idx[None],
conv_state_len,
).to(torch.int64)
spec_conv_offsets = (read_conv_offsets, write_conv_offsets)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
is_decoding,
step_context.block_offsets,
# cu_seqlens is only used in GDN and is passed down via q_start_loc.
# Otherwise, q_start_loc is None.
Expand All @@ -406,6 +454,10 @@ def get_moe_group_name(group):
quant_policy=step_context.kv_quant_policy,
quant_meta=AscendKVQuantMeta.quant_meta,
has_initial_state=has_initial_state,
is_multi_token_decoding=is_multi_token_decoding,
spec_conv_offsets=spec_conv_offsets,
spec_state_offsets=spec_state_offsets,
cache_seqlens=cache_seqlens,
)
step_context.attn_metadata = attn_metadata

Expand Down Expand Up @@ -462,6 +514,12 @@ def init():
logger.warning(f'Error during Ascend initialization: {str(e)}. '
'Please check your Ascend environment configuration.')

try:
import dlinfer.framework.lmdeploy_ext.device # noqa: F401 — triggers vendor_device_init()
except ImportError:
logger.warning('dlinfer framework extensions not found. '
'Ascend-specific model patches will not be applied.')

try:
from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton
init_device_properties_triton()
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class DlinferAttentionMetadata(AttentionMetadata):
quant_meta: dict = None
cu_seq_lens_kv: Tensor | None = None
has_initial_state: Tensor | None = None
is_multi_token_decoding: bool = False
spec_conv_offsets: Sequence[Tensor] = tuple()
spec_state_offsets: Sequence[Tensor] = tuple()
cache_seqlens: Tensor | None = None


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ def from_config(
target_cache_cfg: CacheConfig,
target_model: str = None,
dtype: str = 'auto',
device_type: str = 'auto',
):
model = model or target_model
model_config = ModelConfig.from_pretrained(model,
Expand All @@ -574,6 +575,7 @@ def from_config(
is_draft_model=True,
spec_method=method,
block_size=target_cache_cfg.block_size,
device_type=device_type,
)
cache_config = None
# include medusa
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,6 @@ def build_specdecode_config(target_model, speculative_config: SpeculativeConfig,
target_model=target_model,
target_cache_cfg=cache_config,
dtype=engine_config.dtype,
device_type=engine_config.device_type,
)
return specdecode_config