MLA PS mode add metadata split reference code #2177

Merged
shengnxu merged 6 commits into main from mmd/dev/mla_ps_split_ref
Mar 10, 2026
@minmengdie minmengdie commented Mar 4, 2026

Motivation

Add metadata split reference code for MLA PS mode.

Technical Details

Test Plan

python3 op_tests/test_mla_persistent.py -c 512 -b 1 -n 16,4

Test Result

[Two screenshots of the test results; images not preserved]

Submission Checklist

@minmengdie minmengdie requested review from a team and Copilot March 4, 2026 11:21
@minmengdie minmengdie force-pushed the mmd/dev/mla_ps_split_ref branch from bceedea to 3214a17 Compare March 4, 2026 11:25
Copilot AI left a comment

Pull request overview

This PR extends the MLA persistent-mode test harness with a new reference path that can split KV work using generated MLA metadata and then reduce partial tiles back into the final output, enabling validation of “metadata split” behavior in PS/persistent decode mode.

Changes:

  • Extend ref_masked_attention to support a configurable causal mask diagonal for split-chunk attention.
  • Add Torch reference implementations for KV-split extend (torch_mla_extend_split_kv) and a Python reduction (torch_mla_reduce_v1), plus an end-to-end helper (torch_mla_split_kv_and_reduce).
  • Update decode tests to request logits from mla_decode_fwd and compare split/reduced Torch reference results against the ASM path in persistent mode.
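The split-and-reduce idea summarized above can be illustrated with a minimal pure-Python sketch. It is not the PR's Torch code: `chunk_attention`, `reduce_partials`, and `split_kv_and_reduce` are hypothetical names, and the sketch assumes that the queries are the last `q_len` positions of the sequence and that the mask diagonal follows the lower-triangular-with-offset convention. Each KV chunk yields a normalized partial output plus a log-sum-exp (LSE) per query row, and the reduction reweights the partials by their LSE.

```python
import math


def chunk_attention(q_rows, k_chunk, v_chunk, scale, diagonal):
    """Attention of every query row over one KV chunk.

    Chunk-local key j is visible to query i when j <= i + diagonal.
    Returns (outputs, lses): normalized partial outputs and per-row LSE.
    """
    dim = len(v_chunk[0])
    outs, lses = [], []
    for i, q in enumerate(q_rows):
        # Visible keys always form a prefix of the chunk.
        scores = [
            scale * sum(a * b for a, b in zip(q, k))
            for j, k in enumerate(k_chunk)
            if j <= i + diagonal
        ]
        if not scores:  # chunk is fully masked for this query row
            outs.append([0.0] * dim)
            lses.append(-math.inf)
            continue
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # zip stops at the last visible key, so masked values are skipped.
        outs.append([
            sum(w * v[d] for w, v in zip(weights, v_chunk)) / total
            for d in range(dim)
        ])
        lses.append(m + math.log(total))
    return outs, lses


def reduce_partials(partials):
    """Merge per-chunk (outputs, lses) pairs, weighting each by exp(lse)."""
    n_rows = len(partials[0][0])
    dim = len(partials[0][0][0])
    merged = []
    for i in range(n_rows):
        row_lses = [lses[i] for _, lses in partials]
        m = max(row_lses)  # assumes each row sees at least one key overall
        weights = [math.exp(l - m) for l in row_lses]
        total = sum(weights)
        merged.append([
            sum(w * outs[i][d] for w, (outs, _) in zip(weights, partials)) / total
            for d in range(dim)
        ])
    return merged


def split_kv_and_reduce(q_rows, keys, values, scale, chunk_size):
    """Split KV into fixed-size chunks, attend per chunk, then reduce."""
    kv_len, q_len = len(keys), len(q_rows)
    partials = []
    for start in range(0, kv_len, chunk_size):
        # Global rule: key j is visible to query i when j <= i + (kv_len - q_len).
        # Subtracting the chunk start expresses it in chunk-local indices.
        diagonal = kv_len - q_len - start
        partials.append(chunk_attention(
            q_rows,
            keys[start:start + chunk_size],
            values[start:start + chunk_size],
            scale,
            diagonal,
        ))
    return reduce_partials(partials)


# Toy data: 2 query rows attending causally over 5 KV positions.
q_rows = [[0.1, 0.2], [0.3, -0.4]]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
scale = 1 / math.sqrt(2)

full, _ = chunk_attention(q_rows, keys, values, scale,
                          diagonal=len(keys) - len(q_rows))
split = split_kv_and_reduce(q_rows, keys, values, scale, chunk_size=2)
```

With `chunk_size=2`, the last chunk is fully masked for the first query row, so its LSE is negative infinity and it contributes zero weight in the reduction. The fixed-size chunking here is only a stand-in: the PR's reference derives the split from generated MLA metadata.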

6 comment threads on op_tests/test_mla_persistent.py (3 marked Outdated)
@minmengdie minmengdie force-pushed the mmd/dev/mla_ps_split_ref branch from 3214a17 to 4fd018b Compare March 6, 2026 01:55
@minmengdie minmengdie force-pushed the mmd/dev/mla_ps_split_ref branch from f84ea5d to 070569e Compare March 10, 2026 06:39
@shengnxu shengnxu merged commit 434ee19 into main Mar 10, 2026
24 checks passed
@shengnxu shengnxu deleted the mmd/dev/mla_ps_split_ref branch March 10, 2026 11:05
valarLip pushed a commit that referenced this pull request Mar 18, 2026
* mla ps add metadata split reference test code

* fix test_mla_persistent.py

* fix conflict

* add 3buffer split kv ref code

* add gfx condition

* add test
AMD-yanfeiwang pushed a commit to AMD-yanfeiwang/aiter that referenced this pull request Mar 18, 2026
* mla ps add metadata split reference test code

* fix test_mla_persistent.py

* fix conflict

* add 3buffer split kv ref code

* add gfx condition

* add test