Dsv4 sparse indexer#2998
Conversation
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
There was a problem hiding this comment.
Pull request overview
Adds Triton implementations for DeepSeek-V4 (DSv4) sparse attention and Indexer top-k selection to replace slow Torch fallbacks in ATOM/serving paths.
Changes:
- Introduce
sparse_mqa_sinkTriton op implementing DSv4 sparse MQA forward with attention-sink denominator semantics. - Introduce
dsv4_indexer_topkTriton op implementing DSv4 Indexer scoring + causal top-k, including a dense causal fast path. - Add unit tests for both new ops and register the modules in Triton backward-compat import map.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/test_sparse_mqa_sink.py | Adds correctness test comparing sparse_mqa_sink vs a Torch reference. |
| op_tests/test_dsv4_indexer.py | Adds tests for Indexer dense-causal fast path and scored top-k vs Torch reference. |
| aiter/ops/triton/attention/sparse_mqa_sink.py | Python wrapper for launching the sparse MQA sink Triton kernel. |
| aiter/ops/triton/attention/dsv4_indexer.py | Python wrapper for Indexer scoring + top-k, including dense fast path. |
| aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py | Triton kernel for sparse MQA sink with per-token top-k gather and sink denominator. |
| aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py | Triton kernels for dense causal indices, scoring, and finalizing offset indices. |
| aiter/ops/triton/init.py | Registers dsv4_indexer and sparse_mqa_sink for backward-compatible imports. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
# Conflicts: # aiter/ops/topk.py # csrc/include/rocm_ops.hpp # csrc/include/topk_per_row.h # csrc/kernels/topk_per_row_kernels.cu # op_tests/test_topk_per_row.py
|
sorry, @vgokhale , the previous changes worked but was an underlying issue with mHC and Investigating now |
|
@Oseltamivir that's fine. We are developing our own implementations so this will change eventually anyways. |
|
@vgokhale In that case, please merge if you think it's ok, else, please close. I don't want to hamper Aiter/ATOM work. the current issue is poor evals(~0.5 on gsm8k) at higher concurrencies. At conc=1 accuracy is fine. Will continue to investigate. |
# Conflicts: # csrc/kernels/mhc_kernels.cu
Motivation
DSv4 uses a sparse attention path where each query gathers a small top-k set of compressed KV entries, plus an Indexer path that scores compressed KV entries to produce those top-k indices.
The current ATOM DSv4 integration has correctness-first Torch fallbacks for both paths. Those fallbacks materialize large intermediate tensors and are too slow for serving, especially at
conc > 1. This PR adds AITER Triton kernels for the DSv4 sparse MQA attention sink path and the DSv4 indexer scorer/top-k path so ATOM can avoid the Torch fallback.Technical Details
This PR adds:
sparse_mqa_sink: DSv4 sparse MQA forward with attention-sink denominator semantics.dsv4_indexer_topk: DSv4 Indexer scorer and causal top-k selection without materializing the Torch fallback’s[tokens, heads, committed_kv]score tensor.actual_topk == n_committed, which is common forshort-context DSv4 serving.
The sparse attention kernel supports DSv4’s MQA layout:
q:[num_tokens, num_heads, head_dim]kv:[num_blocks, block_size, head_dim]topk_indices:[num_tokens, topk]attn_sink:[num_heads]The Indexer kernel computes:
score[t, k] = sum_h relu(q[t, h] @ kv[k]) * weights[t, h]then applies the DSv4 causal compressed-token mask and returns offset top-k indices for the downstream sparse attention gather.Relevant downstream integration target: ROCm/ATOM DeepSeek-V4 PR650.
Test Plan
Test Result
Local syntax/import validation passed with:
The branch is clean against current ROCm/aiter:main and contains only the DSv4 sparse/indexer kernel additions plus tests.
Was tested and is being used at SemiAnalysisAI/InferenceX #1229, with runs: https://github.com/SemiAnalysisAI/InferenceX/actions/runs/25193385172
op_tests results: https://github.com/SemiAnalysisAI/InferenceX/actions/runs/25221896798
Submission Checklist
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.