
Qwen3-14B decode: mix-fuse cube+vec epilogues, promote scope-3 / scope-2 to flat pl.spmd #387

Merged
zhangqi-chen merged 4 commits into hw-native-sys:main from lwDavid:qwen3-14b-mix-fuse-spmd on May 26, 2026


@lwDavid lwDavid commented May 26, 2026

Summary

Reworks models/qwen3/14b/decode_layer.py (and adds two variants to models/qwen3/14b/rms_lm_head.py) so that every cube K-loop reduction that was followed by a vec epilogue now runs as ONE mixed cube+vec pl.spmd body with an UP_DOWN row split. This replaces the previous nested shapes (a pl.parallel outer with a pl.at inner or an inner pl.spmd) and the GM bridge tiles they needed.

Mixed cube+vec fusion (5 split regions -> 3 mixed roots)

Each mixed root uses optimizations=[pl.split(pl.SplitMode.UP_DOWN)] so cube + vec ping-pong on BATCH_TILE/2 rows each; without UP_DOWN the per-core UB budget is exceeded under --max-seq.

  • out_proj + out_proj_residual -> one mixed root (name_hint="out_proj"). o_acc stays on L0C across the K-loop and feeds the vec residual via the C2V boundary move (no GM round-trip).
  • down_proj + down_proj_residual -> one mixed root (name_hint="down_proj"); drops the per-iter fp32_chunk_gm GM scratch (~160 KiB / call).
  • gate_proj + up_proj + silu -> one mixed root (name_hint="gate_up_silu"); drops the gate_group / up_group FP32 GM bridges and the two cube matmuls share a SINGLE K-loop so each post_chunk is loaded from L1 once per K-tile and feeds both wg and wu.
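The shared K-loop in gate_up_silu can be illustrated outside the DSL. The following is a minimal plain-Python sketch, not pypto code: the helper names (`gate_up_silu_fused`, `gate_up_silu_unfused`, `matmul`), the tile size, and the `silu(gate) * up` combination in the epilogue are assumptions made for illustration. It only shows that loading each activation chunk once per K-tile and feeding both weight matrices gives the same result as two separate matmuls followed by SiLU.

```python
import math


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matmul(a, b):
    """Plain [M][K] x [K][N] matmul, used only by the unfused reference."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def gate_up_silu_fused(x, wg, wu, k_tile):
    """One K-loop: each activation chunk is sliced once and feeds wg and wu."""
    m, k_dim, n = len(x), len(wg), len(wg[0])
    gate_acc = [[0.0] * n for _ in range(m)]
    up_acc = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k_dim, k_tile):
        k1 = min(k0 + k_tile, k_dim)
        x_chunk = [row[k0:k1] for row in x]  # "loaded" once per K-tile
        for i in range(m):
            for j in range(n):
                for kk, a in enumerate(x_chunk[i]):
                    gate_acc[i][j] += a * wg[k0 + kk][j]
                    up_acc[i][j] += a * wu[k0 + kk][j]
    # Epilogue on the accumulators, with no intermediate buffer in between.
    return [[silu(g) * u for g, u in zip(g_row, u_row)]
            for g_row, u_row in zip(gate_acc, up_acc)]


def gate_up_silu_unfused(x, wg, wu):
    """Reference: two independent matmuls, then the same epilogue."""
    gate, up = matmul(x, wg), matmul(x, wu)
    return [[silu(g) * u for g, u in zip(g_row, u_row)]
            for g_row, u_row in zip(gate, up)]
```

In the real kernel the two accumulators live on-chip and the epilogue runs on the vec unit; the sketch only models the loop structure, not the memory hierarchy.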

pl.parallel -> pl.spmd promotion

out_proj / gate_up_silu / down_proj are now top-level pl.spmd(BLOCK_COUNT, optimizations=[UP_DOWN]) regions instead of a pl.parallel wrapping a pl.at. The spmd body is implicitly InCore, so there is one fewer wrapper layer.

scope-2 flat SPMD dispatch

fa_fused and online_softmax are promoted out of for b in pl.parallel(user_batch) into top-level flat pl.spmd(BATCH * (TOTAL_Q_GROUPS // 2)) (= 64) and pl.spmd(BATCH * TOTAL_Q_GROUPS) (= 128) dispatches. Each block decodes spmd_idx -> (b, g2 or gi) and re-reads its per-b ctx_blocks from seq_lens. Total task count is unchanged, but the runtime sees ONE big spmd pool to load-balance across cube cores instead of 16 small per-batch dispatches; per-batch launch + barrier overhead is gone.

  • all_oi_tmp / all_cur_mi / all_cur_li are promoted from per-b to global GM tensors covering all BATCH batches (+~32 MB GM, well within budget).
  • attn_row per-b intermediate is gone -- online_softmax writes directly to attn_out[b, q_base * HEAD_DIM].
  • rope_kv_cache stays in for b in pl.parallel(user_batch) because the K/V cache slot write + fa_fused QK/SV cache read can't share an InCore region (cross-region barrier is required, otherwise codegen fails with Tensor view not found for parameter: k_cache__tile).
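The flat index decode can be sketched in plain Python. BATCH = 16 and TOTAL_Q_GROUPS = 8 are inferred from the 64 / 128 block counts and the "16 small per-batch dispatches" above; the batch-major ordering (`//` for the batch, `%` for the group) and the function names are assumptions, since the description does not say how spmd_idx is laid out.

```python
BATCH = 16           # inferred from the text above
TOTAL_Q_GROUPS = 8   # inferred from 128 == BATCH * TOTAL_Q_GROUPS


def decode_fa_idx(spmd_idx):
    """fa_fused: flat index -> (b, g2), assuming batch-major order."""
    pairs_per_batch = TOTAL_Q_GROUPS // 2
    return spmd_idx // pairs_per_batch, spmd_idx % pairs_per_batch


def decode_os_idx(spmd_idx):
    """online_softmax: flat index -> (b, gi), assuming batch-major order."""
    return spmd_idx // TOTAL_Q_GROUPS, spmd_idx % TOTAL_Q_GROUPS


fa_blocks = [decode_fa_idx(i) for i in range(BATCH * (TOTAL_Q_GROUPS // 2))]
os_blocks = [decode_os_idx(i) for i in range(BATCH * TOTAL_Q_GROUPS)]
```

Every (b, g2) and (b, gi) pair appears exactly once, which is why the total task count matches the previous per-batch dispatches.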

--lm-head CLI for perf testing

Adds --lm-head {full,skip,single} to decode_layer.py (default full, original behaviour unchanged):

  • skip -- skips the LM-head matmul entirely (~2376 vocab-chunk iterations); out stays at the harness zero init. Uses new rms_only variant in rms_lm_head.py.
  • single -- runs only the first VOCAB_CHUNK iteration; columns past VOCAB_CHUNK stay zero. Uses new rms_lm_head_single_chunk variant.
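A minimal sketch of how such a flag could be wired with argparse. Only the option name, its three choices, and the default come from this PR; treating --max-seq as a boolean switch, the `VARIANT_BY_MODE` table, and `build_parser` are hypothetical and not taken from decode_layer.py.

```python
import argparse

# Variant names as mentioned in this PR; the mapping itself is illustrative.
VARIANT_BY_MODE = {
    "full": "rms_lm_head",
    "skip": "rms_only",
    "single": "rms_lm_head_single_chunk",
}


def build_parser():
    parser = argparse.ArgumentParser(description="decode layer runner (sketch)")
    parser.add_argument(
        "--lm-head",
        choices=sorted(VARIANT_BY_MODE),
        default="full",
        help="full: whole LM head; skip: RMSNorm only; single: first vocab chunk",
    )
    # Assumed to be a plain on/off switch, as in the validation commands.
    parser.add_argument("--max-seq", action="store_true")
    return parser


# Parse an explicit argument list so the sketch does not depend on sys.argv.
args = build_parser().parse_args(["--max-seq", "--lm-head", "skip"])
selected_variant = VARIANT_BY_MODE[args.lm_head]
```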

Golden helpers (golden_decode_layer_no_lm_head, golden_decode_layer_single_lm_head) call the full reference body and then zero out the columns the kernel doesn't write -- relies on run_jit's zero-init of output tensors (verified in golden/runner.py:275,377).
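The masking step those golden helpers perform can be sketched as follows. This is a plain-Python illustration on nested lists, not the actual helper code; `zero_unwritten_columns` and the tiny shapes are hypothetical.

```python
def zero_unwritten_columns(full_out, written_cols):
    """Keep the first `written_cols` columns of each row, zero the rest.

    written_cols == 0 models --lm-head skip (the kernel writes nothing);
    written_cols == VOCAB_CHUNK models --lm-head single.
    """
    return [row[:written_cols] + [0.0] * (len(row) - written_cols)
            for row in full_out]


# Toy "full reference" output: 2 rows, 6 vocab columns, chunk width of 2.
full_reference = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]
golden_skip = zero_unwritten_columns(full_reference, 0)
golden_single = zero_unwritten_columns(full_reference, 2)
```

The comparison is only meaningful because the output buffer starts at zero; if a caller passed in an uninitialised buffer, the unwritten columns would not match this golden.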

⚠️ pl.pipeline state to review

fa_fused's pl.pipeline(2, stage=2) over the Q-group pair is currently commented out in this PR (gi = fa_g2 * 2 only, processing 4 of 8 KV heads). This was for perf-isolation experiments and is masked by --lm-head skip (output forced to zero, attention correctness not validated). Restore the two commented lines before running real validation (--lm-head single / full); the comment in the code already notes this.
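The effect of the commented-out pipeline on group coverage can be checked with a small plain-Python model. It assumes TOTAL_Q_GROUPS = 8 (consistent with the dispatch sizes above) and uses the index formula gi = fa_g2 * 2 + gp; `groups_written` is a hypothetical helper, and a regular loop stands in for pl.pipeline.

```python
TOTAL_Q_GROUPS = 8  # assumed, consistent with the dispatch sizes in this PR


def groups_written(pipeline_enabled):
    """Return the Q-group indices that the fa_fused blocks would write."""
    written = set()
    for fa_g2 in range(TOTAL_Q_GROUPS // 2):
        # With the pipeline: gp in {0, 1}. Without it: only gp == 0.
        for gp in range(2 if pipeline_enabled else 1):
            written.add(fa_g2 * 2 + gp)
    return sorted(written)


without_pipeline = groups_written(pipeline_enabled=False)  # [0, 2, 4, 6]
with_pipeline = groups_written(pipeline_enabled=True)      # [0, 1, ..., 7]
```

With the pipeline disabled only the even groups are produced, so any consumer that iterates over all groups reads slots that were never written.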

Validation

  • task-submit --device auto --run --max-time 0 "... python decode_layer.py --max-seq --lm-head skip": PASS (out (16, 152064) trivially zero == golden zero, end-to-end run on a2a3).
  • task-submit --device auto --run --max-time 0 "... python decode_layer.py --max-seq --lm-head single" with pl.pipeline restored: PASS (rtol/atol 1.5e-2, real attention through one vocab chunk matches the PyTorch reference).
  • Compile clean (0 warnings).

…e-2 to flat pl.spmd

Reworks Qwen3-14B decode_layer.py so every cube K-loop reduction that was
followed by a vec epilogue runs as ONE mixed cube+vec pl.spmd body with
UP_DOWN row split, replacing the previous (pl.parallel outer + pl.at inner
+ inner pl.spmd) shapes and the GM bridge tiles they needed.

Mixed cube+vec fusion (5 split regions -> 3 mixed roots):
- out_proj + out_proj_residual -> one mixed root (name "out_proj")
- down_proj + down_proj_residual -> one mixed root (name "down_proj");
  drops the per-iter fp32_chunk_gm GM scratch.
- gate_proj + up_proj + silu -> one mixed root (name "gate_up_silu");
  drops the gate_group / up_group FP32 GM bridges and the two cube
  matmuls share a single K-loop so each post_chunk is loaded from L1
  once per K-tile and feeds both wg and wu.

Each mixed root uses optimizations=[pl.split(pl.SplitMode.UP_DOWN)] so
cube + vec ping-pong on BATCH_TILE/2 rows each; without UP_DOWN the
per-core UB budget is exceeded under --max-seq.

pl.parallel -> pl.spmd promotion:
- out_proj / gate_up_silu / down_proj are now top-level pl.spmd
  regions (implicitly InCore, no surrounding pl.at).

scope-2 flat SPMD dispatch:
- fa_fused and online_softmax are promoted out of
  for b in pl.parallel(user_batch) into top-level flat
  pl.spmd(BATCH * (TOTAL_Q_GROUPS // 2)) (= 64) and
  pl.spmd(BATCH * TOTAL_Q_GROUPS) (= 128) dispatches. Each block
  decodes spmd_idx -> (b, g2 or gi) and re-reads its per-b ctx_blocks.
  Total task count is unchanged but the runtime sees ONE big spmd pool
  to load-balance across cube cores instead of 16 small per-batch
  dispatches.
- all_oi_tmp / all_cur_mi / all_cur_li promoted from per-b to global GM
  tensors covering all BATCH batches. attn_row per-b intermediate is
  gone -- online_softmax writes directly to attn_out at
  [b, q_base * HEAD_DIM].
- rope_kv_cache stays in for b in pl.parallel(user_batch) because the
  K/V cache slot write + fa_fused QK/SV cache read can't share an
  InCore region (cross-region barrier is required, otherwise codegen
  fails with "Tensor view not found for parameter: k_cache__tile").

--lm-head CLI for perf testing:
- Adds --lm-head {full,skip,single} to decode_layer.py.
- skip: skips the LM-head matmul entirely (out stays at zero init).
  Uses new rms_only variant in rms_lm_head.py.
- single: runs only the first VOCAB_CHUNK iteration. Uses new
  rms_lm_head_single_chunk variant.
- Golden helpers call the full reference body then zero out the
  columns the kernel doesn't write.

NOTE: fa_fused's pl.pipeline(2, stage=2) over the Q-group pair is
currently commented out in this branch (gi = fa_g2 * 2 only, only 4
of 8 KV heads processed). This was for perf-isolation experiments
and is masked by --lm-head skip. Restore the two commented lines
before running real attention validation (--lm-head single / full);
the comment in fa_fused notes this.

Validation:
- task-submit a2a3, --max-seq --lm-head skip: PASS (out (16, 152064)
  trivially zero == golden zero).
- task-submit a2a3, --max-seq --lm-head single, pl.pipeline restored:
  PASS (rtol/atol 1.5e-2, real attention through one vocab chunk
  matches the PyTorch reference).
@coderabbitai

coderabbitai Bot commented May 26, 2026

Review Change Stack

Warning

Review limit reached

@lwDavid, we couldn't start this review because you've reached your PR review rate limit.

More reviews will be available in 20 minutes and 19 seconds. Learn how PR review limits work.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: cb940ad5-4d12-40cf-94a4-21621eb273ec

📥 Commits

Reviewing files that changed from the base of the PR and between c93c444 and 0a86de1.

📒 Files selected for processing (2)
  • models/qwen3/14b/decode_layer.py
  • models/qwen3/14b/rms_lm_head.py
📝 Walkthrough

This PR refactors the Qwen3-14B single-layer decode kernel by restructuring Scope 2 attention dispatch to flat SPMD pools with promoted global scratch tensors, fusing Scope 3 MLP computation with shared K-loop and inline SiLU, and introducing two RMSNorm kernel variants (rms_only and rms_lm_head_single_chunk) that enable expanded test coverage with tightened numerical tolerances.

Changes

Decode Layer Kernel Refactoring

| Layer / File(s) | Summary |
|---|---|
| RMSNorm and LM-head kernel variants (models/qwen3/14b/rms_lm_head.py) | rms_only computes final RMSNorm without LM-head matmul; rms_lm_head_single_chunk computes RMSNorm then LM-head for only the first vocabulary chunk, enabling selective logit computation variants. |
| Scope 2 attention refactoring to flat SPMD dispatch (models/qwen3/14b/decode_layer.py) | Global scratch tensors (all_oi_tmp, all_cur_mi, all_cur_li) promoted to span all batches; fa_fused and online_softmax restructured as top-level flat SPMD pools across batch×group lanes; attn_out assembled directly without per-batch intermediate. rope_kv_cache remains separate due to GM visibility constraints. |
| Scope 3 MLP and projection fusion (models/qwen3/14b/decode_layer.py) | gate_up_silu fused mixed cube+vec SPMD body shares K-loop for gate/up matmuls and executes SiLU in vec epilogue, assembling BF16 directly into mlp_tile; down-projection updated to SPMD mixed-region path with vec residual epilogue, eliminating prior GM scratch bridge. |
| Test harness expansion and validation (models/qwen3/14b/decode_layer.py) | New test entry points test_decode_layer_no_lm_head and test_decode_layer_single_lm_head with corresponding golden functions; --lm-head CLI option (full\|skip\|single) selects variant; numerical tolerances tightened to rtol=3e-3, atol=3e-3. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#360: Directly modifies Scope-2 attention SPMD dispatch and Scope-3 MLP grouping logic in decode_layer.py with similar flat-SPMD and scratch-tensor refactoring patterns.
  • hw-native-sys/pypto-lib#331: Refactors decode_layer.py interaction with RMSNorm/LM-head by introducing and using new variants like rms_only and rms_lm_head_single_chunk in test infrastructure.
  • hw-native-sys/pypto-lib#51: Refactors Scope 2 decode attention kernel dataflow (softmax validity, accumulator scratch, RoPE/cache plumbing) within the same attention dispatch path.

Suggested labels

enhancement

Poem

🐰 A layered kernel blooms anew,
With flattened SPMD paths so true,
RMSNorm splits three gentle ways,
While fused MLP lights the blaze,
The decoder learns to compute and sway! 🌟

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 63.64% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately captures the main changes: mix-fused cube+vec epilogues and promotion of scope-3/scope-2 to flat pl.spmd, matching the core refactoring described in the summary. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The pull request description accurately relates to the changeset, detailing the mixed cube+vec fusion, scope-2 flat SPMD dispatch, and --lm-head CLI additions that are reflected in the file summaries. |


@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request optimizes the Qwen3 14B decode layer by promoting attention stages (fa_fused and online_softmax) to top-level flat SPMD dispatches and fusing projection and MLP layers into mixed cube+vec regions to eliminate global memory round-trips. It also introduces test variants to skip or run a single chunk of the LM head for faster profiling. The review feedback highlights three key issues: the attention head pipelining is currently commented out (hardcoding gi and processing only even heads), the final_normed tensor in the new rms_only function is unused and prone to dead-code elimination, and the tightened test tolerances contradict comments regarding cross-lane drift on a5sim platforms.
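The third finding comes down to arithmetic. The sketch below uses the allclose-style criterion |actual - expected| <= atol + rtol * |expected|; whether the test harness applies exactly this formula is an assumption, and the sample values are made up to represent a drift of about 1.2e-2.

```python
def within_tolerance(actual, expected, rtol, atol):
    """allclose-style elementwise check (assumed comparison rule)."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


expected = 1.0
actual = 1.012  # illustrative value with a ~1.2e-2 absolute drift

# Tight setting: the allowed error is 3e-3 + 3e-3 * 1.0 = 6e-3.
passes_tight = within_tolerance(actual, expected, rtol=3e-3, atol=3e-3)

# Validated setting: the allowed error is 1.5e-2 + 1.5e-2 * 1.0 = 3e-2.
passes_validated = within_tolerance(actual, expected, rtol=1.5e-2, atol=1.5e-2)
```

A drift of this size fails the tight setting and passes the validated one, which is the mismatch the reviewers point at.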

Comment thread models/qwen3/14b/decode_layer.py Outdated
Comment thread models/qwen3/14b/rms_lm_head.py
Comment thread models/qwen3/14b/decode_layer.py
…spmd

# Conflicts:
#	models/qwen3/14b/decode_layer.py

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
models/qwen3/14b/rms_lm_head.py (1)

116-119: ⚡ Quick win

Clarify: out zero-init is guaranteed by the golden harness for these variants
The harness zero-initializes output tensors when TensorSpec(..., is_output=True) is used with the default init_value=None (golden/runner.py sets torch.zeros(...) for pure outputs). In models/qwen3/14b/decode_layer.py, TensorSpec("out", [batch, vocab], torch.float32, is_output=True) is declared without init_value, and the only callers of rms_only / rms_lm_head_single_chunk are through this JIT golden flow—so the “unwritten regions stay zero” assumption holds for correctness here. If these entry points are reused outside the golden harness, they should then write/clear the full out (or document that required external contract).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/rms_lm_head.py` around lines 116 - 119, Clarify that the
zero-initialization of the out tensor is provided by the golden harness: mention
that TensorSpec("out", ..., is_output=True) in models/qwen3/14b/decode_layer.py
(with default init_value=None) is zeroed by golden/runner.py, so rms_only and
rms_lm_head_single_chunk (variants of rms_lm_head) rely on that contract and
intentionally skip writing the LM-head matmul; also add a note in the rms_only /
rms_lm_head_single_chunk docstring or comment that if these functions are ever
used outside the golden harness they must explicitly zero or fully write the out
buffer (or callers must guarantee zero-init) to maintain correctness.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/qwen3/14b/decode_layer.py`:
- Around line 434-436: The current change sets gi = fa_g2 * 2 which causes
fa_fused to write only even Q-groups while online_softmax still iterates over
every os_gi and reads unwritten odd-group scratch leading to incorrect attn_out;
restore the 2-way Q-group loop by reintroducing the inner pipeline (use for gp
in pl.pipeline(2, stage=2): gi = fa_g2 * 2 + gp) so both gi values are handled
per fa_fused, or alternatively adjust the outer dispatch to iterate every gi;
update code around fa_g2, gi, pl.pipeline, fa_fused, online_softmax and os_gi to
ensure both even and odd Q-groups are written before online_softmax reads
attn_out.
- Around line 424-433: The loop that iterates fa_spmd_idx reads
USER_BATCH_DYN-backed tensors (seq_lens, block_table) using computed fa_b/fa_g2
even when padded SPMD lanes exceed runtime_user_batch; fix by masking padded
lanes before any dynamic-tensor access: add a guard like if fa_b >=
runtime_user_batch then take a zero-work path (skip reads/writes and set
fa_ctx_len/fa_ctx_blocks/fa_block_table_base to safe defaults) so no
USER_BATCH_DYN tensors are indexed from padded lanes; apply the same pattern to
the corresponding os_spmd loop that computes os_b/os_g2 before touching
seq_lens/block_table.
- Around line 1320-1321: The rtol/atol values in decode_layer.py are set too
tight (rtol=3e-3, atol=3e-3) despite the nearby comment and validation showing a
~1.1–1.5e-2 drift; change the hardcoded rtol and atol to the validated threshold
(e.g., 1.5e-2) or wrap the tighter 3e-3 values behind a conditional that only
applies when running the exact measured mode/platform, updating the rtol and
atol settings used by the code path that performs the numeric comparison.

---

Nitpick comments:
In `@models/qwen3/14b/rms_lm_head.py`:
- Around line 116-119: Clarify that the zero-initialization of the out tensor is
provided by the golden harness: mention that TensorSpec("out", ...,
is_output=True) in models/qwen3/14b/decode_layer.py (with default
init_value=None) is zeroed by golden/runner.py, so rms_only and
rms_lm_head_single_chunk (variants of rms_lm_head) rely on that contract and
intentionally skip writing the LM-head matmul; also add a note in the rms_only /
rms_lm_head_single_chunk docstring or comment that if these functions are ever
used outside the golden harness they must explicitly zero or fully write the out
buffer (or callers must guarantee zero-init) to maintain correctness.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: ad3f2a1c-fbcb-4437-bcc5-e2ef90177f02

📥 Commits

Reviewing files that changed from the base of the PR and between e0733ec and c93c444.

📒 Files selected for processing (2)
  • models/qwen3/14b/decode_layer.py
  • models/qwen3/14b/rms_lm_head.py

Comment thread models/qwen3/14b/decode_layer.py Outdated
Comment thread models/qwen3/14b/decode_layer.py Outdated
Comment thread models/qwen3/14b/decode_layer.py
lwDavid added 2 commits May 26, 2026 15:46
- Restore pl.pipeline(2, stage=2) in fa_fused -- was commented out, only
  even Q-groups (4 of 8 KV heads) were running and odd-group all_oi_tmp
  slots stayed uninitialised, which caused online_softmax to read NaN and
  the a2a3 CI to validate as NaN (gemini #1, coderabbit #5).
- Clamp fa_b / os_b reads of USER_BATCH_DYN-backed tensors (seq_lens,
  block_table) with pl.min(., user_batch - 1) so padded spmd lanes
  (fa_b >= user_batch when runtime_user_batch < BATCH) never index past
  the dynamic dim. Writes still use the raw padded index and land in
  their own padded slot of the BATCH-sized global scratch, which the
  host trims away (coderabbit #4).
- Drop the now-unused MLP_GROUP_CHUNK import (ruff F401, fixes pre-commit).
- Document the DCE risk on rms_only.final_normed with a TODO that points
  at the fix if perf traces ever show the RMSNorm cost collapsing to
  zero (gemini #2).
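The clamping described in the second bullet can be modelled in plain Python. This is a sketch only: `min` stands in for pl.min, BATCH = 16 is inferred from the PR description, and `clamped_read_index` plus the sample seq_lens are hypothetical.

```python
BATCH = 16  # static dispatch width, inferred from the PR description


def clamped_read_index(lane_b, user_batch):
    """Index used for READS of user_batch-sized tensors from a padded lane."""
    return min(lane_b, user_batch - 1)


seq_lens = [128, 256, 64]    # runtime user_batch == 3, smaller than BATCH
user_batch = len(seq_lens)

# Every lane, padded or not, reads a valid entry; padded lanes re-read the last one.
lane_ctx_len = [seq_lens[clamped_read_index(b, user_batch)] for b in range(BATCH)]

# Writes keep the raw lane index, so each lane owns its own slot in the
# BATCH-sized scratch and padded lanes never overwrite real results.
scratch = [None] * BATCH
for b in range(BATCH):
    scratch[b] = lane_ctx_len[b]
valid_results = scratch[:user_batch]  # the host keeps only the real batches
```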
@zhangqi-chen zhangqi-chen merged commit 2d6f086 into hw-native-sys:main May 26, 2026
6 of 7 checks passed
@lwDavid lwDavid deleted the qwen3-14b-mix-fuse-spmd branch May 26, 2026 08:36