
[plugin][MLA] optimize MLA metadata build and remove D2D copy#387

Open
zejunchen-zejun wants to merge 7 commits into main from zejun/opt_plugin_build_mla
Conversation

@zejunchen-zejun
Contributor

@zejunchen-zejun zejunchen-zejun commented Mar 23, 2026

This PR improves MLA decode performance in ATOM's vLLM plugin mode by reducing per-step metadata work, generating KV indices directly, and removing a redundant device-to-device copy of the positions tensor.

For performance:

| Model | Concurrency | Baseline (total token/s) | This PR (total token/s) | Gain |
| --- | --- | --- | --- | --- |
| Kimi-K2 TP8 | 4 | 666.38 | 678.63 | +1.84% |
| Kimi-K2 TP8 | 32 | 3916.37 | 3967.31 | +1.30% |
| DeepSeek-R1 FP8 TP8 | 4 | 630.41 | 641.05 | +1.69% |
| DeepSeek-R1 FP8 TP8 | 32 | 3270.91 | 3310.07 | +1.20% |

For accuracy:

| Model | Baseline | This PR | Delta |
| --- | --- | --- | --- |
| Kimi-K2 TP8 | 0.9325 | 0.9401 | +0.0076 |
| DeepSeek-R1 FP8 TP8 | 0.9492 | 0.9409 | -0.0083 |

Copilot AI review requested due to automatic review settings March 23, 2026 07:04
@zejunchen-zejun zejunchen-zejun marked this pull request as draft March 23, 2026 07:05
Contributor

Copilot AI left a comment

Pull request overview

This PR optimizes MLA (Multi-head Latent Attention) execution in vLLM plugin mode by reducing per-step metadata overhead and avoiding unnecessary device-to-device copies, with a new env toggle to control persistent decode metadata behavior.

Changes:

  • Add ATOM_USE_PERSISTENT_MLA_DECODE_METADATA env flag to enable/disable persistent MLA decode metadata buffers.
  • Thread runtime positions through forward_context so MLA decode can consume them without extra D2D copies.
  • Optimize MLA decode scheduling by generating paged_kv_indices via a Triton kernel, using in-place cumsum/arange, and reusing decode buffers/imports.
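The direct KV index generation described in the last bullet can be pictured with a small PyTorch sketch. The real plugin uses a Triton kernel (`kv_indices_generate_triton`); the helper name, argument layout, and block-table shape below are illustrative assumptions, not code from the PR:

```python
import torch

# Hypothetical sketch of the paged-KV metadata build: compute per-request
# page counts, build the indptr prefix sum, and gather the used page ids.
# The PR does this on-device with a Triton kernel to avoid host round-trips.
def build_paged_kv_metadata(seq_lens: torch.Tensor,
                            block_table: torch.Tensor,
                            page_size: int):
    num_reqs = seq_lens.numel()
    # Number of KV-cache pages each request occupies (ceil division).
    pages_per_req = (seq_lens + page_size - 1) // page_size
    # Prefix sum into a preallocated indptr buffer; reusing this buffer
    # across decode steps avoids a fresh allocation per step.
    indptr = torch.zeros(num_reqs + 1, dtype=torch.long)
    indptr[1:] = torch.cumsum(pages_per_req, dim=0)
    # Gather the in-use page ids per request from the block table.
    indices = torch.cat([block_table[i, : pages_per_req[i]]
                         for i in range(num_reqs)])
    return indptr, indices
```

The Triton version fuses the gather into one kernel so `paged_kv_indices` is produced directly on the GPU instead of being assembled on the host and copied over.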

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
atom/utils/envs.py Adds env toggle for persistent MLA decode metadata.
atom/plugin/vllm/model_wrapper.py Writes runtime positions into the forward context (fallback to static context).
atom/plugin/attention_mla.py Reuses decode buffers, caches vLLM imports, and switches decode path selection based on persistent metadata enablement.
atom/plugin/attention.py Optimizes MLA decode metadata build (in-place indptr, Triton kv index generation, optional persistent worker buffers).
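The env toggle in `atom/utils/envs.py` presumably follows the common environment-flag pattern; a minimal sketch, assuming a boolean-style "0"/"1" value (the default shown here is an assumption, not taken from the PR):

```python
import os

# Hypothetical reading of the toggle; the real definition in
# atom/utils/envs.py may use a different default or parsing helper.
ATOM_USE_PERSISTENT_MLA_DECODE_METADATA: bool = (
    os.getenv("ATOM_USE_PERSISTENT_MLA_DECODE_METADATA", "1") == "1"
)
```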


D2D copy and use kv_indices_generate_triton

Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
@zejunchen-zejun zejunchen-zejun force-pushed the zejun/opt_plugin_build_mla branch from e1a0fe3 to ba4d185 Compare March 25, 2026 01:58
@wuhuikx wuhuikx marked this pull request as ready for review March 25, 2026 04:39
@wuhuikx wuhuikx changed the title [draft][plugin][MLA] optimize MLA metadata build and remove D2D copy [plugin][MLA] optimize MLA metadata build and remove D2D copy Mar 25, 2026
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Copilot AI review requested due to automatic review settings March 25, 2026 05:34
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.



avoid a D2D copy

Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Copilot AI review requested due to automatic review settings March 25, 2026 06:20
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.



@wuhuikx
Contributor

wuhuikx commented Mar 25, 2026

Hi @ZhangLirong-amd and @XiaobingSuper, could you please take a look at this PR? It brings performance gains on DS and Kimi. The code changes are small, but they touch the MLA critical path, so a careful review would be appreciated.

Copilot AI review requested due to automatic review settings March 25, 2026 14:48
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.



Comment on lines +793 to +797
max_qo_len = (
(query_start_loc_cpu[-1] - query_start_loc_cpu[-2]).item()
if query_start_loc_cpu.numel() > 1
else 1
)
Copilot AI Mar 25, 2026

max_qo_len is now computed using only the last segment length (query_start_loc_cpu[-1] - query_start_loc_cpu[-2]), which underestimates the true max when per-request decode query lengths vary (i.e., when num_decode_tokens != num_reqs). This can lead to incorrect max_qo_len passed into mla_decode_fwd and potential workspace/shape issues. Compute max_qo_len as the maximum of all per-request query lengths (e.g., based on query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) instead of just the last one.
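The reviewer's suggested fix would look roughly like this; a sketch of the proposal (taking the max over all segment lengths), not code from the PR:

```python
import torch

# Compute max_qo_len as the maximum per-request query length, derived
# from consecutive differences of query_start_loc, rather than only the
# last segment (query_start_loc[-1] - query_start_loc[-2]).
def compute_max_qo_len(query_start_loc_cpu: torch.Tensor) -> int:
    if query_start_loc_cpu.numel() <= 1:
        return 1
    seg_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    return int(seg_lens.max().item())
```

This matters only when per-request decode query lengths vary (e.g., multi-token decode); with one query token per request, both formulations agree.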

