Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 77 additions & 26 deletions dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ def aclgraph_use_torch_npu_update():


# AscendCudaGraphMixin methods for cudagraph buffer management.
def AscendCudaGraphMixin_support_cuda_graph(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    past_key_values: List[List[Tensor]],
    attn_metadata: Any = None,
    inputs_embeds: Tensor = None,
    **kwargs,
) -> bool:
    """Decide whether the current step may run through a captured graph.

    Only ``attn_metadata`` is inspected; the remaining parameters exist so the
    signature matches the mixin hook this function is assigned to.

    Args:
        input_ids: Unused here.
        position_ids: Unused here.
        past_key_values: Unused here.
        attn_metadata: Object that may expose ``is_decoding`` and
            ``is_multi_token_decoding``. A missing attribute is treated as
            ``False``.
        inputs_embeds: Unused here.
        **kwargs: Ignored.

    Returns:
        ``False`` when ``attn_metadata`` is ``None``, or when multi-token
        decoding is requested while ``aclgraph_use_torch_npu_update()`` is
        falsy. Otherwise ``True`` if either decoding flag is set.
    """
    if attn_metadata is None:
        return False

    # Coerce to bool so callers always get a strict True/False, even if the
    # metadata stores the flags as None or an int.
    is_decoding = bool(getattr(attn_metadata, "is_decoding", False))
    is_multi_token = bool(getattr(attn_metadata, "is_multi_token_decoding", False))

    # Multi-token decode is only allowed when the torch_npu update path is
    # available (per the original docstring: "runtime length updates").
    # The helper is consulted only for multi-token steps.
    if is_multi_token and not aclgraph_use_torch_npu_update():
        return False
    return is_decoding or is_multi_token


def AscendCudaGraphMixin_make_buffers_cudagraph(
self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
Expand All @@ -58,9 +78,7 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
(max_batches, num_blocks), dtype=torch.int32, device=device
)

input_buffers["q_seqlens"] = torch.ones(
max_batches, dtype=torch.int32, device=device
)
input_buffers["q_seqlens"] = torch.ones(max_batches, dtype=torch.int32)

input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32)

Expand All @@ -69,18 +87,23 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
)

input_buffers["kv_start_indices"] = -torch.ones(
(max_batches), dtype=torch.int32, device=device
(max_tokens), dtype=torch.int32, device=device
)

input_buffers["x_active_mask"] = torch.zeros(
(max_batches), dtype=torch.bool, device=device
(max_tokens), dtype=torch.bool, device=device
)

input_buffers["attention_mask"] = torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=device), diagonal=1)

# ssm
if graph_meta.is_ssm:
input_buffers["state_ids"] = torch.full(
(max_batches,), -1, dtype=torch.int64, device=device
)
input_buffers["cache_seqlens"] = torch.zeros(
max_batches, dtype=torch.int32, device=device
)

# mrope
if graph_meta.use_mrope:
Expand Down Expand Up @@ -108,11 +131,13 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
moe_metadata = get_step_ctx_manager().current_context().moe_metadata
x_active_mask: Tensor = moe_metadata.x_active_mask
q_start_loc: Tensor = attn_metadata.q_start_loc
cache_seqlens: Tensor = attn_metadata.cache_seqlens

input_buffers: BuffType = graph_meta.input_buffers

batch_size, num_blocks = block_offsets.size()
num_tokens = input_ids.size(-1)
q_seqlens: Tensor = attn_metadata.q_seqlens

# fill buffer
max_num_tokens = input_buffers["input_ids"].size(-1)
Expand All @@ -126,22 +151,32 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
input_buffers["position_ids"][:, :num_tokens] = position_ids
input_buffers["block_offsets"].zero_()
input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets
input_buffers["q_seqlens"].fill_(0)
input_buffers["q_seqlens"][: batch_size] = q_seqlens
input_buffers["kv_seqlens"].fill_(0)
input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
input_buffers["kv_start_indices"].fill_(-1)
input_buffers["kv_start_indices"][:batch_size] = kv_start_indices
input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices
if x_active_mask is not None:
input_buffers["x_active_mask"].fill_(0)
input_buffers["x_active_mask"][:batch_size] = x_active_mask
input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask

# ssm
if graph_meta.is_ssm:
input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
input_buffers["q_start_loc"][batch_size + 1 :] = q_start_loc[-1]
bs = input_buffers["q_start_loc"].size(0)
max_q_seq_len = attn_metadata.max_q_seq_len
padding_tensor = torch.arange(0, bs) * max_q_seq_len
input_buffers["q_start_loc"].copy_(padding_tensor)
input_buffers["q_start_loc"][:q_start_loc.size(0)] = q_start_loc

state_ids = kwargs["state_ids"]
input_buffers["state_ids"].fill_(-1)
input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids)
input_buffers["state_ids"].fill_(0)
input_buffers["state_ids"][: batch_size].copy_(state_ids)

input_buffers["cache_seqlens"].fill_(0)
input_buffers["cache_seqlens"][: batch_size].copy_(cache_seqlens)

attn_metadata.cache_seqlens = input_buffers["cache_seqlens"]
attn_metadata.attention_mask = [input_buffers["attention_mask"]]

if inputs_embeds is not None:
emb_size = inputs_embeds.size(-1)
Expand All @@ -151,10 +186,7 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
1, max_num_tokens, emb_size
)
input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds
# create inputs
# Use compatible size but cap at graph's max_batchs to avoid buffer overflow
new_batch_size = min(get_ascend_compatible_size(batch_size), graph_meta.max_batchs)

attn_metadata.q_seqlens = input_buffers["q_seqlens"]
attn_metadata.block_offsets = input_buffers["block_offsets"]
attn_metadata.kv_seqlens = input_buffers["kv_seqlens"]
attn_metadata.kv_start_indices = input_buffers["kv_start_indices"]
Expand All @@ -175,7 +207,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(

new_inputs.update(kwargs)

# ssm: override kwargs' variable-length state_ids with the fixed-size buffer
if graph_meta.is_ssm:
new_inputs["state_ids"] = input_buffers["state_ids"]

Expand Down Expand Up @@ -209,6 +240,7 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
context.mrope_position_ids = input_buffers["mrope_position_ids"]


CudaGraphMixin.support_cuda_graph = AscendCudaGraphMixin_support_cuda_graph
CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph
Expand Down Expand Up @@ -358,7 +390,7 @@ def forward(self, **kwargs):
]
)
else:
update_attn_params(self.update_stream, self.meta, self.max_tokens)
update_attn_params(self.update_stream, self.meta, self.max_batches)
self._graph.replay()
output_buffers = self.meta.output_buffers
output = self.model.get_outputs_cudagraph(output_buffers, **kwargs)
Expand Down Expand Up @@ -427,19 +459,33 @@ def _get_capture_tokens(self, batch_size: int):
def get_graph_key(
    self,
    input_ids: torch.Tensor,
    attn_metadata: Any,
    **kwargs,
):
    """Get graph key.

    The key always has four elements so callers can index ``key[3]`` on
    every path:
    ``(capture_size, decoding_flag, enable_microbatch, max_q_seq_len)``.

    * Multi-token decoding: ``capture_size`` comes from the batch size
      (``q_seqlens.size(0)``), ``decoding_flag`` is
      ``attn_metadata.is_multi_token_decoding`` and ``max_q_seq_len`` is
      taken from ``attn_metadata``.
    * Otherwise: ``capture_size`` comes from ``input_ids.numel()``,
      ``decoding_flag`` is ``attn_metadata.is_decoding`` and the last
      element is ``1``.

    Args:
        input_ids: Token ids for the current step; only ``numel()`` is used.
        attn_metadata: Must expose ``is_decoding`` and
            ``is_multi_token_decoding``; in the multi-token case also
            ``q_seqlens`` and ``max_q_seq_len``.
        **kwargs: Ignored.

    Returns:
        A 4-tuple usable as a dictionary key for the runner map.
    """
    is_decoding = attn_metadata.is_decoding
    is_multi_token_decoding = attn_metadata.is_multi_token_decoding
    meta = self.get_meta()
    enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch

    if is_multi_token_decoding:
        q_seqlens = attn_metadata.q_seqlens
        max_q_seq_len = attn_metadata.max_q_seq_len
        batch_size = q_seqlens.size(0)
        if meta.padding_batch_size is None:
            new_batch_size = self._get_capture_tokens(batch_size)
        else:
            # NOTE(review): padding_batch_size is treated as a token count
            # here and ceil-divided by max_q_seq_len to get a batch count;
            # confirm against the producer of meta.padding_batch_size.
            padding_num_tokens = meta.padding_batch_size
            padding_batch_size = (
                padding_num_tokens + max_q_seq_len - 1
            ) // max_q_seq_len
            new_batch_size = self._get_capture_tokens(padding_batch_size)
        return (
            new_batch_size,
            is_multi_token_decoding,
            enable_microbatch,
            max_q_seq_len,
        )

    num_tokens = input_ids.numel()
    if meta.padding_batch_size is None:
        new_num_tokens = self._get_capture_tokens(num_tokens)
    else:
        new_num_tokens = self._get_capture_tokens(meta.padding_batch_size)
    # Trailing 1 keeps the key the same length as the multi-token key.
    return (new_num_tokens, is_decoding, enable_microbatch, 1)

def __call__(self, **kwargs):
"""call."""
Expand All @@ -451,16 +497,21 @@ def __call__(self, **kwargs):
return self.model.make_output_buffers(ret)

graph_key = self.get_graph_key(**kwargs)
max_tokens = graph_key[0]
is_decoding = graph_key[1]
max_batches = graph_key[0]
is_decoding_or_multi_token_decoding = graph_key[1]
max_q_seq_len = graph_key[3]
if is_decoding_or_multi_token_decoding:
max_tokens = max_batches * max_q_seq_len
else:
max_tokens = max_batches
max_batches = self.max_batches
if graph_key not in self._runner_map:
max_batches = max_tokens if is_decoding else self.max_batches
runner = AscendSingleGraphRunner(
self.model,
max_batches=max_batches,
max_tokens=max_tokens,
num_blocks=self.num_blocks,
is_decoding=is_decoding,
is_decoding=is_decoding_or_multi_token_decoding,
pool=self.graph_pool_handle,
model_config=self.model_config,
device=self.device,
Expand Down
Loading