Qwen3-14B decode: mix-fuse cube+vec epilogues, promote scope-3 / scope-2 to flat pl.spmd#387
Conversation
…e-2 to flat pl.spmd
Reworks Qwen3-14B decode_layer.py so every cube K-loop reduction that was
followed by a vec epilogue runs as ONE mixed cube+vec pl.spmd body with
UP_DOWN row split, replacing the previous (pl.parallel outer + pl.at inner
+ inner pl.spmd) shapes and the GM bridge tiles they needed.
Mixed cube+vec fusion (5 split regions -> 3 mixed roots):
- out_proj + out_proj_residual -> one mixed root (name "out_proj")
- down_proj + down_proj_residual -> one mixed root (name "down_proj");
drops the per-iter fp32_chunk_gm GM scratch.
- gate_proj + up_proj + silu -> one mixed root (name "gate_up_silu");
drops the gate_group / up_group FP32 GM bridges and the two cube
matmuls share a single K-loop so each post_chunk is loaded from L1
once per K-tile and feeds both wg and wu.
Each mixed root uses optimizations=[pl.split(pl.SplitMode.UP_DOWN)] so
cube + vec ping-pong on BATCH_TILE/2 rows each; without UP_DOWN the
per-core UB budget is exceeded under --max-seq.
pl.parallel -> pl.spmd promotion:
- out_proj / gate_up_silu / down_proj are now top-level pl.spmd
regions (implicitly InCore, no surrounding pl.at).
scope-2 flat SPMD dispatch:
- fa_fused and online_softmax are promoted out of
for b in pl.parallel(user_batch) into top-level flat
pl.spmd(BATCH * (TOTAL_Q_GROUPS // 2)) (= 64) and
pl.spmd(BATCH * TOTAL_Q_GROUPS) (= 128) dispatches. Each block
decodes spmd_idx -> (b, g2 or gi) and re-reads its per-b ctx_blocks.
Total task count is unchanged but the runtime sees ONE big spmd pool
to load-balance across cube cores instead of 16 small per-batch
dispatches.
- all_oi_tmp / all_cur_mi / all_cur_li promoted from per-b to global GM
tensors covering all BATCH batches. attn_row per-b intermediate is
gone -- online_softmax writes directly to attn_out at
[b, q_base * HEAD_DIM].
- rope_kv_cache stays in for b in pl.parallel(user_batch) because the
K/V cache slot write + fa_fused QK/SV cache read can't share an
InCore region (cross-region barrier is required, otherwise codegen
fails with "Tensor view not found for parameter: k_cache__tile").
--lm-head CLI for perf testing:
- Adds --lm-head {full,skip,single} to decode_layer.py.
- skip: skips the LM-head matmul entirely (out stays at zero init).
Uses new rms_only variant in rms_lm_head.py.
- single: runs only the first VOCAB_CHUNK iteration. Uses new
rms_lm_head_single_chunk variant.
- Golden helpers call the full reference body then zero out the
columns the kernel doesn't write.
NOTE: fa_fused's pl.pipeline(2, stage=2) over the Q-group pair is
currently commented out in this branch (gi = fa_g2 * 2 only, only 4
of 8 KV heads processed). This was for perf-isolation experiments
and is masked by --lm-head skip. Restore the two commented lines
before running real attention validation (--lm-head single / full);
the comment in fa_fused notes this.
Validation:
- task-submit a2a3, --max-seq --lm-head skip: PASS (out (16, 152064)
trivially zero == golden zero).
- task-submit a2a3, --max-seq --lm-head single, pl.pipeline restored:
PASS (rtol/atol 1.5e-2, real attention through one vocab chunk
matches the PyTorch reference).
|
Warning Review limit reached
More reviews will be available in 20 minutes and 19 seconds. Learn how PR review limits work. Your organization has run out of usage credits. Purchase more in the billing tab. ⌛ How to resolve this issue?After more reviews become available, a review can be triggered using the We recommend that you space out your commits to avoid hitting the rate limit. 🚦 How do rate limits work?CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans include higher PR review limits than trial, open-source, and free plans. In all cases, reviews become available again over time. During sustained high-volume PR review activity, CodeRabbit may temporarily slow when the next review becomes available. Please see our Fair Usage Limits Policy for further information. ℹ️ Review info⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
📝 WalkthroughWalkthroughThis PR refactors the Qwen3-14B single-layer decode kernel by restructuring Scope 2 attention dispatch to flat SPMD pools with promoted global scratch tensors, fusing Scope 3 MLP computation with shared K-loop and inline SiLU, and introducing two RMSNorm kernel variants ( ChangesDecode Layer Kernel Refactoring
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request optimizes the Qwen3 14B decode layer by promoting attention stages (fa_fused and online_softmax) to top-level flat SPMD dispatches and fusing projection and MLP layers into mixed cube+vec regions to eliminate global memory round-trips. It also introduces test variants to skip or run a single chunk of the LM head for faster profiling. The review feedback highlights three key issues: the attention head pipelining is currently commented out (hardcoding gi and processing only even heads), the final_normed tensor in the new rms_only function is unused and prone to dead-code elimination, and the tightened test tolerances contradict comments regarding cross-lane drift on a5sim platforms.
…spmd # Conflicts: # models/qwen3/14b/decode_layer.py
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (1)
models/qwen3/14b/rms_lm_head.py (1)
116-119: ⚡ Quick winClarify:
outzero-init is guaranteed by the golden harness for these variants
The harness zero-initializes output tensors whenTensorSpec(..., is_output=True)is used with the defaultinit_value=None(golden/runner.pysetstorch.zeros(...)for pure outputs). Inmodels/qwen3/14b/decode_layer.py,TensorSpec("out", [batch, vocab], torch.float32, is_output=True)is declared withoutinit_value, and the only callers ofrms_only/rms_lm_head_single_chunkare through this JIT golden flow—so the “unwritten regions stay zero” assumption holds for correctness here. If these entry points are reused outside the golden harness, they should then write/clear the fullout(or document that required external contract).🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/qwen3/14b/rms_lm_head.py` around lines 116 - 119, Clarify that the zero-initialization of the out tensor is provided by the golden harness: mention that TensorSpec("out", ..., is_output=True) in models/qwen3/14b/decode_layer.py (with default init_value=None) is zeroed by golden/runner.py, so rms_only and rms_lm_head_single_chunk (variants of rms_lm_head) rely on that contract and intentionally skip writing the LM-head matmul; also add a note in the rms_only / rms_lm_head_single_chunk docstring or comment that if these functions are ever used outside the golden harness they must explicitly zero or fully write the out buffer (or callers must guarantee zero-init) to maintain correctness.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/qwen3/14b/decode_layer.py`:
- Around line 434-436: The current change sets gi = fa_g2 * 2 which causes
fa_fused to write only even Q-groups while online_softmax still iterates over
every os_gi and reads unwritten odd-group scratch leading to incorrect attn_out;
restore the 2-way Q-group loop by reintroducing the inner pipeline (use for gp
in pl.pipeline(2, stage=2): gi = fa_g2 * 2 + gp) so both gi values are handled
per fa_fused, or alternatively adjust the outer dispatch to iterate every gi;
update code around fa_g2, gi, pl.pipeline, fa_fused, online_softmax and os_gi to
ensure both even and odd Q-groups are written before online_softmax reads
attn_out.
- Around line 424-433: The loop that iterates fa_spmd_idx reads
USER_BATCH_DYN-backed tensors (seq_lens, block_table) using computed fa_b/fa_g2
even when padded SPMD lanes exceed runtime_user_batch; fix by masking padded
lanes before any dynamic-tensor access: add a guard like if fa_b >=
runtime_user_batch then take a zero-work path (skip reads/writes and set
fa_ctx_len/fa_ctx_blocks/fa_block_table_base to safe defaults) so no
USER_BATCH_DYN tensors are indexed from padded lanes; apply the same pattern to
the corresponding os_spmd loop that computes os_b/os_g2 before touching
seq_lens/block_table.
- Around line 1320-1321: The rtol/atol values in decode_layer.py are set too
tight (rtol=3e-3, atol=3e-3) despite the nearby comment and validation showing a
~1.1–1.5e-2 drift; change the hardcoded rtol and atol to the validated threshold
(e.g., 1.5e-2) or wrap the tighter 3e-3 values behind a conditional that only
applies when running the exact measured mode/platform, updating the rtol and
atol settings used by the code path that performs the numeric comparison.
---
Nitpick comments:
In `@models/qwen3/14b/rms_lm_head.py`:
- Around line 116-119: Clarify that the zero-initialization of the out tensor is
provided by the golden harness: mention that TensorSpec("out", ...,
is_output=True) in models/qwen3/14b/decode_layer.py (with default
init_value=None) is zeroed by golden/runner.py, so rms_only and
rms_lm_head_single_chunk (variants of rms_lm_head) rely on that contract and
intentionally skip writing the LM-head matmul; also add a note in the rms_only /
rms_lm_head_single_chunk docstring or comment that if these functions are ever
used outside the golden harness they must explicitly zero or fully write the out
buffer (or callers must guarantee zero-init) to maintain correctness.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: ad3f2a1c-fbcb-4437-bcc5-e2ef90177f02
📒 Files selected for processing (2)
models/qwen3/14b/decode_layer.pymodels/qwen3/14b/rms_lm_head.py
- Restore pl.pipeline(2, stage=2) in fa_fused -- was commented out, only even Q-groups (4 of 8 KV heads) were running and odd-group all_oi_tmp slots stayed uninitialised, which caused online_softmax to read NaN and the a2a3 CI to validate as NaN (gemini #1, coderabbit hw-native-sys#5). - Clamp fa_b / os_b reads of USER_BATCH_DYN-backed tensors (seq_lens, block_table) with pl.min(., user_batch - 1) so padded spmd lanes (fa_b >= user_batch when runtime_user_batch < BATCH) never index past the dynamic dim. Writes still use the raw padded index and land in their own padded slot of the BATCH-sized global scratch, which the host trims away (coderabbit hw-native-sys#4). - Drop the now-unused MLP_GROUP_CHUNK import (ruff F401, fixes pre-commit). - Document the DCE risk on rms_only.final_normed with a TODO that points at the fix if perf traces ever show the RMSNorm cost collapsing to zero (gemini hw-native-sys#2).
Summary
Reworks
models/qwen3/14b/decode_layer.py(and adds two variants tomodels/qwen3/14b/rms_lm_head.py) so every cube K-loop reduction that was followed by a vec epilogue runs as ONE mixed cube+vecpl.spmdbody with UP_DOWN row split, replacing the previous (pl.parallelouter +pl.atinner / innerpl.spmd) shapes and the GM bridge tiles they needed.Mixed cube+vec fusion (5 split regions -> 3 mixed roots)
Each mixed root uses
optimizations=[pl.split(pl.SplitMode.UP_DOWN)]so cube + vec ping-pong onBATCH_TILE/2rows each; without UP_DOWN the per-core UB budget is exceeded under--max-seq.name_hint="out_proj").o_accstays on L0C across the K-loop and feeds the vec residual via the C2V boundary move (no GM round-trip).name_hint="down_proj"); drops the per-iterfp32_chunk_gmGM scratch (~160 KiB / call).name_hint="gate_up_silu"); drops thegate_group/up_groupFP32 GM bridges and the two cube matmuls share a SINGLE K-loop so eachpost_chunkis loaded from L1 once per K-tile and feeds bothwgandwu.pl.parallel->pl.spmdpromotionout_proj/gate_up_silu/down_projare now top-levelpl.spmd(BLOCK_COUNT, optimizations=[UP_DOWN])regions instead ofpl.paralleloutside apl.at-- spmd body is implicitly InCore, one fewer wrapper layer.scope-2 flat SPMD dispatch
fa_fusedandonline_softmaxare promoted out offor b in pl.parallel(user_batch)into top-level flatpl.spmd(BATCH * (TOTAL_Q_GROUPS // 2))(= 64) andpl.spmd(BATCH * TOTAL_Q_GROUPS)(= 128) dispatches. Each block decodesspmd_idx -> (b, g2 or gi)and re-reads its per-bctx_blocksfromseq_lens. Total task count is unchanged, but the runtime sees ONE big spmd pool to load-balance across cube cores instead of 16 small per-batch dispatches; per-batch launch + barrier overhead is gone.all_oi_tmp/all_cur_mi/all_cur_liare promoted from per-b to global GM tensors covering allBATCHbatches (+~32 MB GM, well within budget).attn_rowper-b intermediate is gone --online_softmaxwrites directly toattn_out[b, q_base * HEAD_DIM].rope_kv_cachestays infor b in pl.parallel(user_batch)because the K/V cache slot write +fa_fusedQK/SV cache read can't share an InCore region (cross-region barrier is required, otherwise codegen fails withTensor view not found for parameter: k_cache__tile).--lm-headCLI for perf testingAdds
--lm-head {full,skip,single}todecode_layer.py(defaultfull, original behaviour unchanged):skip-- skips the LM-head matmul entirely (~2376 vocab-chunk iterations);outstays at the harness zero init. Uses newrms_onlyvariant inrms_lm_head.py.single-- runs only the firstVOCAB_CHUNKiteration; columns pastVOCAB_CHUNKstay zero. Uses newrms_lm_head_single_chunkvariant.Golden helpers (
golden_decode_layer_no_lm_head,golden_decode_layer_single_lm_head) call the full reference body and then zero out the columns the kernel doesn't write -- relies onrun_jit's zero-init of output tensors (verified ingolden/runner.py:275,377).pl.pipelinestate to reviewfa_fused'spl.pipeline(2, stage=2)over the Q-group pair is currently commented out in this PR (gi = fa_g2 * 2only, processing 4 of 8 KV heads). This was for perf-isolation experiments and is masked by--lm-head skip(output forced to zero, attention correctness not validated). Restore the two commented lines before running real validation (--lm-head single/full); the comment in the code already notes this.Validation
task-submit --device auto --run --max-time 0 "... python decode_layer.py --max-seq --lm-head skip": PASS (out (16, 152064)trivially zero == golden zero, end-to-end run on a2a3).task-submit --device auto --run --max-time 0 "... python decode_layer.py --max-seq --lm-head single"withpl.pipelinerestored: PASS (rtol/atol 1.5e-2, real attention through one vocab chunk matches the PyTorch reference).