Skip to content

dsv4: refactor decode qkv_proj_rope with pl.spmd and scope fusion#385

Merged
zhangqi-chen merged 8 commits into
hw-native-sys:mainfrom
zhangqi-chen:dsv4-qkv-proj-rope-refactor
May 26, 2026
Merged

dsv4: refactor decode qkv_proj_rope with pl.spmd and scope fusion#385
zhangqi-chen merged 8 commits into
hw-native-sys:mainfrom
zhangqi-chen:dsv4-qkv-proj-rope-refactor

Conversation

@zhangqi-chen
Copy link
Copy Markdown
Collaborator

Summary

  • Rename qkv_proj_rope.py to decode_qkv_proj_rope.py to disambiguate from the new prefill kernel
  • Fuse attn_norm, qr_rms_norm + qr_quant, kv_rms + kv_norm_nope, qproj_matmul + qproj_dequant scopes (qproj uses T-tile to keep cube+vec under Vec UB)
  • Convert all parallel scopes to pl.spmd; convert most inner pl.range to pl.pipeline(stage=2)
  • T-tile parallelism for attn_norm / qr_rms_norm_quant / kv_rms_norm (per-token reduction)
  • Flatten wq_b_scale to [H * HEAD_DIM] and propagate to all 6 decode-side callers
  • Drop stale *_BLOCKS / *_GROUP constants, table-aligned formatting, stage-label comments, and tuning-log references

…nd constant inlining

- Fuse attn_norm_rms + attn_norm_apply, qr_rms_norm + qr_norm_apply,
  qr_quant_amax + qr_quant_apply, qproj_matmul + qproj_dequant,
  kv_rms + kv_norm_nope scopes
- T-tile parallelism for attn_norm and qr_rms_norm (per-token reduce)
- qproj merged scope splits dequant on T_TILE to keep cube+vec under Vec UB
- Halve task counts for qr_proj_matmul and qproj
- Inline *_BLOCKS, _GROUP, and stale conditional CHUNK expressions
- Rename *_CHUNK constants to *_TILE for consistency
- Drop chunked_loop_optimizer and partial-sum scaffolding (Opt S/U)
- Convert all parallel scopes (attn_norm, qr_proj_matmul, qr_rms_norm,
  qproj, q_head_rms_nope, q_head_rope, q_rope_reassemble, q_rope_write,
  kv_proj_matmul) from pl.parallel + pl.at to pl.spmd dispatch
- Halve kv_proj_matmul task count (group=2 inner pl.range)
- Fuse qr_rms_norm + qr_quant into one T-tiled spmd scope using a
  two-pass design: pass 1 computes amax without GM staging, pass 2
  recomputes norm and quantizes; drops qr_bf16 and qr_scale_dq GM
  intermediates and keeps kernel outputs at 2 to sidestep the pypto
  multi-InOut OptimizeOrchTensors alias bug
- qproj_dequant reads qr_scale directly (same values as qr_scale_dq)
- Split q_rope reassemble/write into two spmd scopes sharing a GM
  staging tensor (cube+vec mixing in one scope blows Vec UB)
- Drop stage labels, tuning-log references, and bug-rationale notes;
  keep only short, code-functional comments
- Remove table-aligned constants and tensor signatures
- Flatten wq_b_scale from [H_BLOCKS, OUT_TILE] to [H * HEAD_DIM] so the
  external shape no longer depends on internal tiling; reshape per-tile
  at the qproj use site
- Move q_rope_pair_stage allocation next to its first writer
- Drop a stale assert (H * HEAD_DIM) % (HEAD_TILE * 8) (recombines
  unrelated constraints already implied by the loop bounds)
- Halve T_TILE (16 -> 8) to double attn_norm and qr_rms_norm_quant SPMD
  block counts; introduce QPROJ_T_TILE=16 for the qproj dequant T loop
  (cube innerRows alignment) and KV_RMS_T_TILE=16 for the new
  kv_rms_norm SPMD scope
- Convert kv_rms_norm from a single pl.at scope into an 8-way SPMD
  scope over T (per-token reduction)
- Convert remaining pl.range loops inside SPMD / pl.at to
  pl.pipeline(stage=2); keep pl.range on q_head_rms_nope and
  q_head_rope outer h_inner loops where pipelining would blow Vec UB
- Inline ROPE_PAIR_TILE as 32 (no longer derived)
Match the qkv_proj_rope signature change: drop the
[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK] 2D shape and the corresponding
.view() at TensorSpec build time. Updated decode_attention_{csa,hca,swa}
and decode_{csa,hca,swa}. Draft callers (*_draft.py) left untouched.
Disambiguates the decode-side kernel from the new prefill_qkv_proj_rope
module. Updates all `from qkv_proj_rope import ...` callers (3 decode
attention files, 1 prefill draft, and prefill_qkv_proj_rope itself).
Function names (qkv_proj_rope, golden_qkv_proj_rope) are unchanged.
@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 26, 2026

Review Change Stack

Warning

Review limit reached

@zhangqi-chen, we couldn't start this review because you've reached your PR review rate limit.

More reviews will be available in 46 minutes and 30 seconds. Learn how PR review limits work.

Your organization has run out of usage credits. Purchase more in the billing tab.

⌛ How to resolve this issue?

After more reviews become available, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans include higher PR review limits than trial, open-source, and free plans. In all cases, reviews become available again over time. During sustained high-volume PR review activity, CodeRabbit may temporarily slow when the next review becomes available.

Please see our Fair Usage Limits Policy for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: f870792a-44cc-407a-9f58-2f29ac9cb445

📥 Commits

Reviewing files that changed from the base of the PR and between dc2e1e0 and 34d5215.

📒 Files selected for processing (3)
  • models/deepseek/v4/prefill_attention_swa_draft.py
  • models/deepseek/v4/prefill_qkv_proj_rope.py
  • models/deepseek/v4/prefill_sparse_attn_draft.py
📝 Walkthrough

Walkthrough

This PR replaces the legacy qkv_proj_rope module with a new decode_qkv_proj_rope implementation and harmonizes the quantization-scale tensor contract across DeepSeek-V4 decode stacks. The wq_b_scale tensor shape is standardized to flat 1D [H * HEAD_DIM] throughout, with corresponding removal of reshape logic and updated tensor-spec builders across all three attention variants (CSA, HCA, SWA).

Changes

DeepSeek-V4 decode QKV projection kernel refactor

Layer / File(s) Summary
New decode QKV projection fused kernel
models/deepseek/v4/decode_qkv_proj_rope.py
Introduces complete fused RMSNorm + Q-LoRA INT8 quantization + per-head RMSNorm + RoPE rotation kernel for single-token decode, with JIT wrapper, Torch golden reference, randomized tensor specs with fixed RoPE selectors, and CLI test harness with configurable compare tolerances.
CSA decode stack: entry point and orchestrator
models/deepseek/v4/decode_csa.py, models/deepseek/v4/decode_attention_csa.py
Entry-point and orchestrator signatures update wq_b_scale from 2D [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK] to flat [H * HEAD_DIM]. Orchestrator imports decode_qkv_proj_rope, removes tensor reshape logic, and updates TensorSpec to emit/expect flat layout.
HCA decode stack: entry point and orchestrator
models/deepseek/v4/decode_hca.py, models/deepseek/v4/decode_attention_hca.py
Entry-point and orchestrator update wq_b_scale shape to flat [H * HEAD_DIM]. Orchestrator imports decode_qkv_proj_rope, removes reshape/view, and updates TensorSpec shape annotations for consistency.
SWA decode stack: entry point and orchestrator
models/deepseek/v4/decode_swa.py, models/deepseek/v4/decode_attention_swa.py
Entry-point and orchestrator update wq_b_scale to flat [H * HEAD_DIM]. Orchestrator imports decode_qkv_proj_rope, removes reshape logic from spec builder, and updates TensorSpec shape to match new contract.
Prefill module reference updates
models/deepseek/v4/prefill_attention_swa_draft.py, models/deepseek/v4/prefill_qkv_proj_rope.py
Prefill modules import golden reference and tensor-spec builder from new decode_qkv_proj_rope location instead of legacy qkv_proj_rope, ensuring consistency with refreshed decode infrastructure.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#332: Directly removes the legacy qkv_proj_rope.py module that this PR refactors, centralizing QKV decode projection under the new decode_qkv_proj_rope implementation.
  • hw-native-sys/pypto-lib#234: Both PRs update DeepSeek-V4 decode stacks around wq_b_scale tensor contracts and integration of fused QKV projection kernels with INT8 quantization.
  • hw-native-sys/pypto-lib#339: Related refactor of the legacy qkv_proj_rope pipeline internals and RoPE even/odd rotation logic that underlies the new decode_qkv_proj_rope implementation.

Poem

🐰 A kernel reborn, sleek and refined,
wq_b_scale flattened, no reshape to find,
DeepSeek's decode now speaks the same tongue,
CSA, HCA, SWA—all three in song!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 15.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly describes the main change: refactoring the decode qkv_proj_rope kernel with pl.spmd and scope fusion, which aligns with the primary objective of the PR.
Description check ✅ Passed The description provides specific details about the refactoring work including module renaming, scope fusion, parallel scope conversions, T-tile parallelism, wq_b_scale flattening, and removal of stale code, all matching the changeset.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the DeepSeek-V4 decode attention kernels by flattening the wq_b_scale tensor from a 2D shape to a 1D shape ([H * HEAD_DIM]), simplifying the scale layout. It also introduces a new file decode_qkv_proj_rope.py to handle the single-token decode projection and RoPE fusion. The code reviewer identified several critical issues in the new file, including a data race caused by defining col_acc outside a pipelined loop, and multiple Read-After-Write (RAW) hazards on reduction variables (x_sq_sum, qr_sq_sum, qr_tile_amax, q_head_sq_sum, and kv_sq_sum) due to incorrect loop pipelining. Additionally, a high-severity shape mismatch was noted in prefill_qkv_proj_rope.py because it imports the updated decode tensor specs where the scale tensor is now 1D instead of 2D.

Comment thread models/deepseek/v4/decode_qkv_proj_rope.py
Comment thread models/deepseek/v4/decode_qkv_proj_rope.py
Comment thread models/deepseek/v4/decode_qkv_proj_rope.py
Comment thread models/deepseek/v4/decode_qkv_proj_rope.py
Comment thread models/deepseek/v4/decode_qkv_proj_rope.py
Comment thread models/deepseek/v4/decode_qkv_proj_rope.py
Comment thread models/deepseek/v4/prefill_qkv_proj_rope.py
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
models/deepseek/v4/prefill_qkv_proj_rope.py (2)

396-396: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

2D indexing requires shape contract alignment.

If adopting the decode module's flat 1D wq_b_scale: [H * HEAD_DIM] shape, this 2D indexing must be updated to flat indexing. The equivalent flat index would be: (hbg + h_inner) * Q_PROJ_OUT_CHUNK : (hbg + h_inner + 1) * Q_PROJ_OUT_CHUNK.

♻️ Proposed fix if adopting flat 1D shape
-                    w_scale = wq_b_scale[hbg + h_inner : hbg + h_inner + 1, :]
+                    flat_idx = (hbg + h_inner) * Q_PROJ_OUT_CHUNK
+                    w_scale = pl.reshape(
+                        wq_b_scale[flat_idx : flat_idx + Q_PROJ_OUT_CHUNK],
+                        [1, Q_PROJ_OUT_CHUNK]
+                    )

Also update the signatures at lines 92 and 507:

-    wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
+    wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/prefill_qkv_proj_rope.py` at line 396, The code is using
2D slicing on wq_b_scale but the buffer is flat 1D; replace the 2D slice
wq_b_scale[hbg + h_inner : hbg + h_inner + 1, :] with a flat slice using
Q_PROJ_OUT_CHUNK, i.e. use indices (hbg + h_inner) * Q_PROJ_OUT_CHUNK : (hbg +
h_inner + 1) * Q_PROJ_OUT_CHUNK to extract w_scale, and update any functions
that declare or accept wq_b_scale to expect a 1D array (adjust the signature and
any callers that pass wq_b_scale accordingly—the two places that currently
declare/accept wq_b_scale should be changed to the 1D contract).

630-646: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Override wq_b_scale spec if prefill maintains 2D shape contract.

The code passes through specs from _build_qkv_tensor_specs() without overriding wq_b_scale. If the decode module produces flat 1D but this prefill kernel must maintain 2D shape (due to implementation at line 396), add an override here to reconstruct the 2D spec.

♻️ Proposed fix to override wq_b_scale spec for 2D shape
         elif spec.name == "qr_scale":
             specs.append(TensorSpec("qr_scale", [T, 1], torch.float32, is_output=True))
+        elif spec.name == "wq_b_scale":
+            # Prefill kernel uses 2D shape, override if decode spec is flat 1D
+            specs.append(TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32))
         else:
             specs.append(spec)

Note: This override is only necessary if prefill keeps 2D shape while decode uses 1D. If prefill adopts flat 1D, update the kernel implementation instead (see comment on line 396).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/prefill_qkv_proj_rope.py` around lines 630 - 646, The loop
over specs from _build_qkv_tensor_specs() fails to override the wq_b_scale spec
when prefill requires a 2D shape; update the loop to detect spec.name ==
"wq_b_scale" and replace it with a TensorSpec that reconstructs the 2D shape
expected by the prefill kernel (for example TensorSpec("wq_b_scale", [T,
HEAD_DIM], same dtype as original) while preserving the original init_value and
is_output flags); this ensures the prefill implementation (the kernel that
assumes a 2D contract) receives the correct shape instead of the unmodified 1D
spec.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/deepseek/v4/prefill_attention_swa_draft.py`:
- Line 64: The prefill_attention_swa function still expects a 2D wq_b_scale but
the decode-side spec now provides a flattened 1D wq_b_scale of shape [H *
HEAD_DIM]; update prefill_attention_swa to accept the 1D array and reshape it
internally to [H, HEAD_DIM] (or broadcast/reshape as needed) before using it, or
add a short validation/reshape step that detects 1D input and transforms it into
the 2D layout; refer to the prefill_attention_swa function and the wq_b_scale
parameter name (and any uses of HEAD_DIM / H inside that function) to implement
the change so argument binding matches the flattened contract.

---

Outside diff comments:
In `@models/deepseek/v4/prefill_qkv_proj_rope.py`:
- Line 396: The code is using 2D slicing on wq_b_scale but the buffer is flat
1D; replace the 2D slice wq_b_scale[hbg + h_inner : hbg + h_inner + 1, :] with a
flat slice using Q_PROJ_OUT_CHUNK, i.e. use indices (hbg + h_inner) *
Q_PROJ_OUT_CHUNK : (hbg + h_inner + 1) * Q_PROJ_OUT_CHUNK to extract w_scale,
and update any functions that declare or accept wq_b_scale to expect a 1D array
(adjust the signature and any callers that pass wq_b_scale accordingly—the two
places that currently declare/accept wq_b_scale should be changed to the 1D
contract).
- Around line 630-646: The loop over specs from _build_qkv_tensor_specs() fails
to override the wq_b_scale spec when prefill requires a 2D shape; update the
loop to detect spec.name == "wq_b_scale" and replace it with a TensorSpec that
reconstructs the 2D shape expected by the prefill kernel (for example
TensorSpec("wq_b_scale", [T, HEAD_DIM], same dtype as original) while preserving
the original init_value and is_output flags); this ensures the prefill
implementation (the kernel that assumes a 2D contract) receives the correct
shape instead of the unmodified 1D spec.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: d97415d5-46c8-4f25-9172-144af6f43e5f

📥 Commits

Reviewing files that changed from the base of the PR and between a71551e and dc2e1e0.

📒 Files selected for processing (10)
  • models/deepseek/v4/decode_attention_csa.py
  • models/deepseek/v4/decode_attention_hca.py
  • models/deepseek/v4/decode_attention_swa.py
  • models/deepseek/v4/decode_csa.py
  • models/deepseek/v4/decode_hca.py
  • models/deepseek/v4/decode_qkv_proj_rope.py
  • models/deepseek/v4/decode_swa.py
  • models/deepseek/v4/prefill_attention_swa_draft.py
  • models/deepseek/v4/prefill_qkv_proj_rope.py
  • models/deepseek/v4/qkv_proj_rope.py
💤 Files with no reviewable changes (1)
  • models/deepseek/v4/qkv_proj_rope.py

Comment thread models/deepseek/v4/prefill_attention_swa_draft.py
The decode build_tensor_specs now produces wq_b_scale as 1D
[H * HEAD_DIM]; prefill imports those specs but kept the old 2D
[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK] signature, which made shapes
disagree at jit time. Flatten the prefill signature too and slice +
reshape at the dequant use site.
Drop the obsolete _draft file. prefill_attention_swa_draft now imports
golden_prefill_sparse_attn from the non-draft prefill_sparse_attn
module.
@zhangqi-chen zhangqi-chen merged commit e0733ec into hw-native-sys:main May 26, 2026
5 of 7 checks passed
@zhangqi-chen zhangqi-chen deleted the dsv4-qkv-proj-rope-refactor branch May 26, 2026 07:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant