dsv4: refactor decode qkv_proj_rope with pl.spmd and scope fusion #385
…nd constant inlining
- Fuse attn_norm_rms + attn_norm_apply, qr_rms_norm + qr_norm_apply, qr_quant_amax + qr_quant_apply, qproj_matmul + qproj_dequant, kv_rms + kv_norm_nope scopes
- T-tile parallelism for attn_norm and qr_rms_norm (per-token reduce)
- qproj merged scope splits dequant on T_TILE to keep cube+vec under Vec UB
- Halve task counts for qr_proj_matmul and qproj
- Inline *_BLOCKS, *_GROUP, and stale conditional CHUNK expressions
- Rename *_CHUNK constants to *_TILE for consistency
- Drop chunked_loop_optimizer and partial-sum scaffolding (Opt S/U)
- Convert all parallel scopes (attn_norm, qr_proj_matmul, qr_rms_norm, qproj, q_head_rms_nope, q_head_rope, q_rope_reassemble, q_rope_write, kv_proj_matmul) from pl.parallel + pl.at to pl.spmd dispatch
- Halve kv_proj_matmul task count (group=2 inner pl.range)
- Fuse qr_rms_norm + qr_quant into one T-tiled spmd scope using a two-pass design: pass 1 computes amax without GM staging, pass 2 recomputes norm and quantizes; drops qr_bf16 and qr_scale_dq GM intermediates and keeps kernel outputs at 2 to sidestep the pypto multi-InOut OptimizeOrchTensors alias bug
- qproj_dequant reads qr_scale directly (same values as qr_scale_dq)
- Split q_rope reassemble/write into two spmd scopes sharing a GM staging tensor (cube+vec mixing in one scope blows Vec UB)
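The two-pass design in the qr_rms_norm + qr_quant bullet can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the kernel's code: the symmetric int8 range (`QMAX = 127`), the `eps` value, the column tile size, and all function names here are hypothetical. It only shows why recomputing the norm in pass 2 makes the staged intermediate unnecessary.

```python
import math

QMAX = 127  # assumption: symmetric int8 range; the PR does not state the format


def norm_tile(row, gamma, inv_rms, start, stop):
    """Normalized values for columns [start, stop) of one token row."""
    return [row[c] * inv_rms * gamma[c] for c in range(start, stop)]


def rms_norm_quant_two_pass(row, gamma, tile=4, eps=1e-6):
    """Fused norm + quant for one token; the normalized row is never stored."""
    n = len(row)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / n + eps)

    # Pass 1: running amax over normalized tiles; nothing is staged.
    amax = 0.0
    for start in range(0, n, tile):
        for v in norm_tile(row, gamma, inv_rms, start, min(start + tile, n)):
            amax = max(amax, abs(v))
    scale = max(amax, 1e-12) / QMAX

    # Pass 2: recompute each normalized tile and quantize it.
    quant = []
    for start in range(0, n, tile):
        for v in norm_tile(row, gamma, inv_rms, start, min(start + tile, n)):
            quant.append(max(-QMAX, min(QMAX, round(v / scale))))
    return quant, scale


def rms_norm_quant_staged(row, gamma, eps=1e-6):
    """Reference: materialize the normalized row first (the staged variant)."""
    n = len(row)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / n + eps)
    normed = [v * inv_rms * g for v, g in zip(row, gamma)]
    scale = max(max(abs(v) for v in normed), 1e-12) / QMAX
    return [max(-QMAX, min(QMAX, round(v / scale))) for v in normed], scale
```

Both variants return the same quantized values and scale; the two-pass one trades a second norm computation for not writing the normalized row anywhere.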
- Drop stage labels, tuning-log references, and bug-rationale notes; keep only short, code-functional comments
- Remove table-aligned constants and tensor signatures
- Flatten wq_b_scale from [H_BLOCKS, OUT_TILE] to [H * HEAD_DIM] so the external shape no longer depends on internal tiling; reshape per-tile at the qproj use site
- Move q_rope_pair_stage allocation next to its first writer
- Drop a stale assert (H * HEAD_DIM) % (HEAD_TILE * 8) (recombines unrelated constraints already implied by the loop bounds)
- Halve T_TILE (16 -> 8) to double attn_norm and qr_rms_norm_quant SPMD block counts; introduce QPROJ_T_TILE=16 for the qproj dequant T loop (cube innerRows alignment) and KV_RMS_T_TILE=16 for the new kv_rms_norm SPMD scope
- Convert kv_rms_norm from a single pl.at scope into an 8-way SPMD scope over T (per-token reduction)
- Convert remaining pl.range loops inside SPMD / pl.at to pl.pipeline(stage=2); keep pl.range on q_head_rms_nope and q_head_rope outer h_inner loops where pipelining would blow Vec UB
- Inline ROPE_PAIR_TILE as 32 (no longer derived)
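The relation between T_TILE and SPMD block count, and why a per-token reduction tiles cleanly over T, can be sketched as follows. This is a plain-Python illustration, not pypto code: `spmd_blocks`, `rms_norm_token`, and `rms_norm_tiled` are hypothetical names, and T = 128 is an assumption inferred from the "8-way SPMD scope" with KV_RMS_T_TILE=16.

```python
import math


def spmd_blocks(num_tokens, t_tile):
    """(start, stop) token ranges, one per SPMD block."""
    assert num_tokens % t_tile == 0, "T must be a multiple of the tile"
    return [(s, s + t_tile) for s in range(0, num_tokens, t_tile)]


def rms_norm_token(x, eps=1e-6):
    """RMS norm of one token; the reduction runs over that token's columns only."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms for v in x]


def rms_norm_tiled(tokens, t_tile):
    """Each block touches only its own tokens, so blocks can run in any order."""
    out = [None] * len(tokens)
    for start, stop in spmd_blocks(len(tokens), t_tile):
        for t in range(start, stop):
            out[t] = rms_norm_token(tokens[t])
    return out


T = 128  # assumed token count, not stated in the PR
blocks_before = len(spmd_blocks(T, 16))  # T_TILE = 16
blocks_after = len(spmd_blocks(T, 8))    # T_TILE = 8
```

Because no value crosses a token boundary, halving the tile changes only how many blocks are dispatched, not the result.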
Match the qkv_proj_rope signature change: drop the
[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK] 2D shape and the corresponding
.view() at TensorSpec build time. Updated decode_attention_{csa,hca,swa}
and decode_{csa,hca,swa}. Draft callers (*_draft.py) left untouched.
Disambiguates the decode-side kernel from the new prefill_qkv_proj_rope module. Updates all `from qkv_proj_rope import ...` callers (3 decode attention files, 1 prefill draft, and prefill_qkv_proj_rope itself). Function names (qkv_proj_rope, golden_qkv_proj_rope) are unchanged.
📝 Walkthrough
This PR replaces the legacy …
Changes: DeepSeek-V4 decode QKV projection kernel refactor
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request refactors the DeepSeek-V4 decode attention kernels by flattening the wq_b_scale tensor from a 2D shape to a 1D shape ([H * HEAD_DIM]), simplifying the scale layout. It also introduces a new file decode_qkv_proj_rope.py to handle the single-token decode projection and RoPE fusion. The code reviewer identified several critical issues in the new file, including a data race caused by defining col_acc outside a pipelined loop, and multiple Read-After-Write (RAW) hazards on reduction variables (x_sq_sum, qr_sq_sum, qr_tile_amax, q_head_sq_sum, and kv_sq_sum) due to incorrect loop pipelining. Additionally, a high-severity shape mismatch was noted in prefill_qkv_proj_rope.py because it imports the updated decode tensor specs where the scale tensor is now 1D instead of 2D.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
models/deepseek/v4/prefill_qkv_proj_rope.py (2)
396-396: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
2D indexing requires shape contract alignment.
If adopting the decode module's flat 1D `wq_b_scale: [H * HEAD_DIM]` shape, this 2D indexing must be updated to flat indexing. The equivalent flat index would be: `(hbg + h_inner) * Q_PROJ_OUT_CHUNK : (hbg + h_inner + 1) * Q_PROJ_OUT_CHUNK`.
♻️ Proposed fix if adopting flat 1D shape
    - w_scale = wq_b_scale[hbg + h_inner : hbg + h_inner + 1, :]
    + flat_idx = (hbg + h_inner) * Q_PROJ_OUT_CHUNK
    + w_scale = pl.reshape(
    +     wq_b_scale[flat_idx : flat_idx + Q_PROJ_OUT_CHUNK],
    +     [1, Q_PROJ_OUT_CHUNK]
    + )
Also update the signatures at lines 92 and 507:
    - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
    + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/prefill_qkv_proj_rope.py` at line 396, The code is using 2D slicing on wq_b_scale but the buffer is flat 1D; replace the 2D slice wq_b_scale[hbg + h_inner : hbg + h_inner + 1, :] with a flat slice using Q_PROJ_OUT_CHUNK, i.e. use indices (hbg + h_inner) * Q_PROJ_OUT_CHUNK : (hbg + h_inner + 1) * Q_PROJ_OUT_CHUNK to extract w_scale, and update any functions that declare or accept wq_b_scale to expect a 1D array (adjust the signature and any callers that pass wq_b_scale accordingly—the two places that currently declare/accept wq_b_scale should be changed to the 1D contract).
630-646: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Override `wq_b_scale` spec if prefill maintains 2D shape contract.
The code passes through specs from `_build_qkv_tensor_specs()` without overriding `wq_b_scale`. If the decode module produces flat 1D but this prefill kernel must maintain 2D shape (due to implementation at line 396), add an override here to reconstruct the 2D spec.
♻️ Proposed fix to override wq_b_scale spec for 2D shape
      elif spec.name == "qr_scale":
          specs.append(TensorSpec("qr_scale", [T, 1], torch.float32, is_output=True))
    + elif spec.name == "wq_b_scale":
    +     # Prefill kernel uses 2D shape, override if decode spec is flat 1D
    +     specs.append(TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32))
      else:
          specs.append(spec)
Note: This override is only necessary if prefill keeps 2D shape while decode uses 1D. If prefill adopts flat 1D, update the kernel implementation instead (see comment on line 396).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/prefill_qkv_proj_rope.py` around lines 630 - 646, The loop over specs from _build_qkv_tensor_specs() fails to override the wq_b_scale spec when prefill requires a 2D shape; update the loop to detect spec.name == "wq_b_scale" and replace it with a TensorSpec that reconstructs the 2D shape expected by the prefill kernel (for example TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], same dtype as original) while preserving the original init_value and is_output flags); this ensures the prefill implementation (the kernel that assumes a 2D contract) receives the correct shape instead of the unmodified 1D spec.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/prefill_attention_swa_draft.py`:
- Line 64: The prefill_attention_swa function still expects a 2D wq_b_scale but
the decode-side spec now provides a flattened 1D wq_b_scale of shape [H *
HEAD_DIM]; update prefill_attention_swa to accept the 1D array and reshape it
internally to [H, HEAD_DIM] (or broadcast/reshape as needed) before using it, or
add a short validation/reshape step that detects 1D input and transforms it into
the 2D layout; refer to the prefill_attention_swa function and the wq_b_scale
parameter name (and any uses of HEAD_DIM / H inside that function) to implement
the change so argument binding matches the flattened contract.
📒 Files selected for processing (10)
- models/deepseek/v4/decode_attention_csa.py
- models/deepseek/v4/decode_attention_hca.py
- models/deepseek/v4/decode_attention_swa.py
- models/deepseek/v4/decode_csa.py
- models/deepseek/v4/decode_hca.py
- models/deepseek/v4/decode_qkv_proj_rope.py
- models/deepseek/v4/decode_swa.py
- models/deepseek/v4/prefill_attention_swa_draft.py
- models/deepseek/v4/prefill_qkv_proj_rope.py
- models/deepseek/v4/qkv_proj_rope.py
💤 Files with no reviewable changes (1)
- models/deepseek/v4/qkv_proj_rope.py
The decode build_tensor_specs now produces wq_b_scale as 1D [H * HEAD_DIM]; prefill imports those specs but kept the old 2D [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK] signature, which made shapes disagree at jit time. Flatten the prefill signature too and slice + reshape at the dequant use site.
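The "slice + reshape at the dequant use site" step relies on the flat layout being the row-major flattening of the old 2D one. A small plain-Python check of that equivalence, using nested lists in place of tensors; the helper names and the sizes are hypothetical, and it assumes H * HEAD_DIM == Q_PROJ_HEAD_BLOCKS * Q_PROJ_OUT_CHUNK:

```python
def make_scale_2d(head_blocks, out_chunk):
    """Old layout: one row of out_chunk scales per head block."""
    return [[float(r * out_chunk + c) for c in range(out_chunk)]
            for r in range(head_blocks)]


def flatten(scale_2d):
    """New layout: row-major flattening of the 2D scales."""
    return [v for row in scale_2d for v in row]


def tile_from_flat(scale_flat, block, out_chunk):
    """Slice one block's scales from the flat layout, viewed as [1, out_chunk]."""
    start = block * out_chunk
    return [scale_flat[start:start + out_chunk]]


HEAD_BLOCKS, OUT_CHUNK = 4, 8  # illustrative sizes only
scale_2d = make_scale_2d(HEAD_BLOCKS, OUT_CHUNK)
scale_flat = flatten(scale_2d)
```

For every block index `b`, `tile_from_flat(scale_flat, b, OUT_CHUNK)` yields the same values as the old row slice `scale_2d[b:b + 1]`, so only the indexing at the use site changes, not the data.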
Drop the obsolete _draft file. prefill_attention_swa_draft now imports golden_prefill_sparse_attn from the non-draft prefill_sparse_attn module.
Summary
- Rename `qkv_proj_rope.py` to `decode_qkv_proj_rope.py` to disambiguate from the new prefill kernel
- Convert all parallel scopes to `pl.spmd`; convert most inner `pl.range` to `pl.pipeline(stage=2)`
- Flatten `wq_b_scale` to `[H * HEAD_DIM]` and propagate to all 6 decode-side callers
- Drop `*_BLOCKS` / `*_GROUP` constants, table-aligned formatting, stage-label comments, and tuning-log references