[plugin][MLA] optimize MLA metadata build and remove D2D copy #387
zejunchen-zejun wants to merge 7 commits into main from
Conversation
Pull request overview
This PR optimizes MLA (Multi-head Latent Attention) execution in vLLM plugin mode by reducing per-step metadata overhead and avoiding unnecessary device-to-device copies, with a new env toggle to control persistent decode metadata behavior.
Changes:
- Add `ATOM_USE_PERSISTENT_MLA_DECODE_METADATA` env flag to enable/disable persistent MLA decode metadata buffers.
- Thread runtime `positions` through `forward_context` so MLA decode can consume them without extra D2D copies.
- Optimize MLA decode scheduling by generating `paged_kv_indices` via a Triton kernel, using in-place `cumsum`/`arange`, and reusing decode buffers/imports.
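For reference, the work the Triton kernel does can be sketched on the CPU: for each request, copy its valid block-table entries into a flat index array at the offset recorded in the indptr. The function and argument names below are illustrative, not taken from the PR, and this plain-Python version only mirrors the semantics, not the parallel implementation.

```python
def build_paged_kv_indices(block_table, paged_kv_indptr):
    """CPU sketch of paged-KV index generation.

    block_table: per-request rows of physical KV block ids (padded).
    paged_kv_indptr: cumulative count of valid blocks per request,
    e.g. [0, 2, 3] means request 0 uses 2 blocks, request 1 uses 1.
    """
    total = paged_kv_indptr[-1]
    paged_kv_indices = [0] * total
    for req, row in enumerate(block_table):
        start = paged_kv_indptr[req]
        end = paged_kv_indptr[req + 1]
        # Keep only the valid prefix of this request's block-table row.
        paged_kv_indices[start:end] = row[: end - start]
    return paged_kv_indices

# Two requests: the first uses 2 KV blocks, the second uses 1.
print(build_paged_kv_indices([[10, 11, 0, 0], [20, 0, 0, 0]], [0, 2, 3]))
# → [10, 11, 20]
```

Doing this in a single kernel launch on device avoids building the index list on the host and transferring it every decode step.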
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `atom/utils/envs.py` | Adds env toggle for persistent MLA decode metadata. |
| `atom/plugin/vllm/model_wrapper.py` | Writes runtime positions into the forward context (fallback to static context). |
| `atom/plugin/attention_mla.py` | Reuses decode buffers, caches vLLM imports, and switches decode path selection based on persistent metadata enablement. |
| `atom/plugin/attention.py` | Optimizes MLA decode metadata build (in-place indptr, Triton kv index generation, optional persistent worker buffers). |
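The positions change in `model_wrapper.py` can be sketched as follows: the model wrapper stores a reference to the live positions tensor in the shared forward context, and the MLA decode path reads that reference back instead of issuing a device-to-device copy into its own buffer. All class and function names here are assumptions for illustration; a plain list stands in for a device tensor.

```python
class ForwardContext:
    """Minimal stand-in for the per-step forward context object."""

    def __init__(self):
        self.positions = None  # populated each step by the model wrapper


def set_forward_context_positions(ctx, positions):
    # Store a reference to the runtime positions; no device copy happens.
    ctx.positions = positions


def get_decode_positions(ctx):
    # MLA decode consumes the same object directly, avoiding a D2D copy.
    return ctx.positions


ctx = ForwardContext()
positions = [0, 1, 2, 3]  # stands in for a device tensor
set_forward_context_positions(ctx, positions)
assert get_decode_positions(ctx) is positions  # same object, nothing copied
```

The fallback to a static context mentioned in the table would apply when no per-step forward context is available.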
D2D copy and use kv_indices_generate_triton Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
avoid a D2D copy Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Hi @ZhangLirong-amd and @XiaobingSuper, could you please take a look at this PR? It has performance benefits on DS and Kimi. Even though the code changes are small, they sit in the MLA critical path, so a careful review would be much appreciated.
for kimi-k2 Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
```python
max_qo_len = (
    (query_start_loc_cpu[-1] - query_start_loc_cpu[-2]).item()
    if query_start_loc_cpu.numel() > 1
    else 1
)
```
`max_qo_len` is now computed using only the last segment length (`query_start_loc_cpu[-1] - query_start_loc_cpu[-2]`), which underestimates the true max when per-request decode query lengths vary (i.e., when `num_decode_tokens != num_reqs`). This can lead to an incorrect `max_qo_len` passed into `mla_decode_fwd` and potential workspace/shape issues. Compute `max_qo_len` as the maximum of all per-request query lengths (e.g., based on `query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]`) instead of just the last one.
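The suggested fix can be sketched in plain Python, taking the maximum over all per-request query lengths rather than only the last segment. Here `query_start_loc` is illustrative CPU data standing in for `query_start_loc_cpu`:

```python
def max_qo_len(query_start_loc):
    """query_start_loc holds cumulative query offsets, one entry per
    request boundary, e.g. [0, 4, 5, 7] for three requests."""
    if len(query_start_loc) <= 1:
        return 1
    # Max over ALL per-request lengths, not just the last segment.
    return max(b - a for a, b in zip(query_start_loc, query_start_loc[1:]))


print(max_qo_len([0, 4, 5, 7]))  # → 4 (last-segment-only would return 2)
```

With torch tensors, the equivalent would be the `.max()` of `query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]`, which matches the single-segment version whenever every decode request contributes exactly one token.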
This PR improves MLA decode performance in ATOM's vLLM plugin mode by reducing per-step metadata work, using direct KV index generation, and removing a redundant positions device-to-device copy.
For performance:
For accuracy: