Skip to content

Speed up Qwen3-14B LM head matmul and add standalone test driver#391

Open
luohuan19 wants to merge 2 commits into
hw-native-sys:mainfrom
luohuan19:rms_lm_head_speedup
Open

Speed up Qwen3-14B LM head matmul and add standalone test driver#391
luohuan19 wants to merge 2 commits into
hw-native-sys:mainfrom
luohuan19:rms_lm_head_speedup

Conversation

@luohuan19
Copy link
Copy Markdown
Contributor

Summary

  • Speed up the Qwen3-14B rms_lm_head LM-head matmul: override LM_HEAD_K_CHUNK=512 / VOCAB_CHUNK=256 and add LM_HEAD_OB_CHUNK=2 to widen cube tasks and amortise dispatch; hoist pl.at(CORE_GROUP, "lm_head") outside the vocab pl.parallel with optimizations=[pl.auto_chunk] so the matmul/accumulate ladder can fuse.
  • Drop the lm_acc_gm staging tensor and the second store scope; trim with pl.set_validshape (metadata-only) instead of pl.slice(..., valid_shape=...) on the accumulator, which forced a tmov that ptoas rejects.
  • Add a standalone test_rms_lm_head @pl.jit driver with build_tensor_specs, a torch golden_rms_lm_head reference, and a __main__ entry so the kernel runs via the standard golden harness (python models/qwen3/14b/rms_lm_head.py -p <platform>).
  • Wall-clock on a2a3 b=16: 1524 -> 1447 us (5% on top of the prior 2.5x); within 5.5% of the 24-core compute floor (~1367 us).

- Override LM_HEAD_K_CHUNK/VOCAB_CHUNK to 256 and add OB_CHUNK=4.
  Config defaults (64/128) left cube cores at ~45% utilisation
  behind dispatch bubbles; the wider N+K and ob chunking amortise
  per-task dispatch and lift the innermost K dim to one L2 line.
- Hoist pl.at(CORE_GROUP, lm_head) outside the vocab pl.parallel
  and pass optimizations=[pl.auto_chunk] so the matmul-accumulate
  ladder can fuse.
- Drop the lm_acc_gm staging tensor and second store scope; trim
  with pl.set_validshape (metadata only) instead of pl.slice with
  valid_shape, which forced a rejected acc->acc tmov.
- Add test_rms_lm_head jit driver, build_tensor_specs, torch
  golden, and a __main__ entry so the kernel runs via the
  standard golden harness.
@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 26, 2026

Review Change Stack

Warning

Review limit reached

@luohuan19, we couldn't start this review because you've reached your PR review rate limit.

More reviews will be available in 15 minutes and 27 seconds. Learn how PR review limits work.

Your organization has run out of usage credits. Purchase more in the billing tab.

⌛ How to resolve this issue?

After more reviews become available, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans include higher PR review limits than trial, open-source, and free plans. In all cases, reviews become available again over time. During sustained high-volume PR review activity, CodeRabbit may temporarily slow when the next review becomes available.

Please see our Fair Usage Limits Policy for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 9b7348f1-4eef-475f-8e05-9c263ea5abeb

📥 Commits

Reviewing files that changed from the base of the PR and between 2149695 and ac5eedc.

📒 Files selected for processing (1)
  • models/qwen3/14b/rms_lm_head.py
📝 Walkthrough

Walkthrough

Updated rms_lm_head.py with local LM-head tiling constants (LM_HEAD_K_CHUNK, VOCAB_CHUNK, LM_HEAD_OB_CHUNK), modified kernel chunking strategy with pl.parallel and pl.auto_chunk, changed partial-row trimming to metadata-only pl.set_validshape, and added complete JIT test harness with tensor specs builder, PyTorch golden reference, and CLI.

Changes

RMSNorm LM-head Tiling and Testing

Layer / File(s) Summary
Tiling Constants and Kernel Optimization
models/qwen3/14b/rms_lm_head.py
Introduced module-level chunk-size constants (LM_HEAD_K_CHUNK, VOCAB_CHUNK, LM_HEAD_OB_CHUNK) to override config defaults. Modified rms_lm_head compute loop to parallelize LM-head outer iterations with pl.parallel(..., chunk=LM_HEAD_OB_CHUNK) and added pl.auto_chunk optimization. Changed invalid-row trimming from store-time valid_shape slicing to metadata-only pl.set_validshape(lm_acc, lm_valid_rows, VOCAB_CHUNK) applied before output assembly.
Test Harness and Validation Infrastructure
models/qwen3/14b/rms_lm_head.py
Added test_rms_lm_head JIT kernel that copies valid hidden_states using valid_shape-bounded slices and calls rms_lm_head. Added build_tensor_specs helper to generate randomized test inputs with dynamic batch behavior. Added golden_rms_lm_head PyTorch reference that performs chunked per-row RMSNorm followed by tiled LM-head matmul. Added __main__ CLI entry point with platform/device/batch/L2 arguments to run run_jit tests and report failures.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#331: Both PRs directly modify models/qwen3/14b/rms_lm_head.py implementation details including chunking and tiled matmul behavior for the Qwen3-14B LM-head.
  • hw-native-sys/pypto-lib#387: Both PRs modify models/qwen3/14b/rms_lm_head.py; the main PR changes tiling/validshape trimming and adds a test harness, while the retrieved PR adds rms_only and rms_lm_head_single_chunk variants that depend on the same RMSNorm and LM-head kernel details.

Poem

🐰 The RMS head now chunks with grace,
With validshape trimming in place,
Constants guide the tile design,
Golden tests make sure all align,
Parallel loops run—what a race!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 18.18% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately summarizes the main changes: optimizing the Qwen3-14B LM head matmul and adding a standalone test driver.
Description check ✅ Passed The description provides comprehensive details about the optimization changes, test driver additions, and measured performance improvements, all directly related to the changeset.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes the rms_lm_head kernel by overriding chunk configurations, applying pl.auto_chunk optimizations, and replacing pl.slice with pl.set_validshape to avoid unnecessary data movement. It also adds a JIT test function, tensor spec builders, and a PyTorch golden reference implementation. The review feedback highlights potential issues where dynamic batch sizes smaller than the loop offset could result in negative values for lm_valid_rows and cur_valid, potentially causing compilation or runtime errors. Clamping these values to a minimum of 0 is recommended.

# forces a tmov to a new acc tile with a different slayout,
# which ptoas rejects ("expects a supported tmov address-space
# pair"). set_validshape is metadata-only — no data movement.
lm_acc_trimmed = pl.set_validshape(lm_acc, lm_valid_rows, VOCAB_CHUNK)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If user_batch is dynamic and less than b0, lm_valid_rows will be negative. Passing a negative row count to pl.set_validshape can lead to undefined behavior or compilation failures. Clamp lm_valid_rows to a minimum of 0 to prevent this.

Suggested change
lm_acc_trimmed = pl.set_validshape(lm_acc, lm_valid_rows, VOCAB_CHUNK)
lm_acc_trimmed = pl.set_validshape(lm_acc, pl.max(0, lm_valid_rows), VOCAB_CHUNK)

user_batch = pl.tensor.dim(hidden_states, 0)
current_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16)
for b0 in pl.parallel(0, BATCH, BATCH_TILE):
cur_valid = pl.min(BATCH_TILE, user_batch - b0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When user_batch is less than b0 (which can happen when user_batch is dynamic and smaller than BATCH), user_batch - b0 will be negative. This results in a negative cur_valid value being passed to pl.slice as part of valid_shape, which can cause compilation errors or undefined runtime behavior. Clamp cur_valid to a minimum of 0 using pl.max(0, ...) to ensure safety.

Suggested change
cur_valid = pl.min(BATCH_TILE, user_batch - b0)
cur_valid = pl.max(0, pl.min(BATCH_TILE, user_batch - b0))

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
models/qwen3/14b/rms_lm_head.py (1)

335-339: 💤 Low value

Unused vocab_chunk parameter.

The vocab_chunk parameter is defined but never used in tiled_lm_head. The function only tiles along the K dimension, not the vocab dimension. This doesn't affect correctness but is dead code.

♻️ Remove unused parameter
-    def tiled_lm_head(lhs, rhs_t, k_chunk, vocab_chunk):
+    def tiled_lm_head(lhs, rhs_t, k_chunk):
         out = torch.zeros(lhs.shape[0], rhs_t.shape[0], dtype=torch.float32)
         for k0 in range(0, lhs.shape[1], k_chunk):
             out = out + lhs[:, k0 : k0 + k_chunk].float() @ rhs_t[:, k0 : k0 + k_chunk].float().T
         return out

And update the call site at line 345-350:

     tensors["out"][:] = tiled_lm_head(
         final_normed,
         lm_head_weight,
         LM_HEAD_K_CHUNK,
-        VOCAB_CHUNK,
     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/rms_lm_head.py` around lines 335 - 339, The function
tiled_lm_head declares an unused vocab_chunk parameter; remove vocab_chunk from
the tiled_lm_head signature and from all its call sites so the function only
accepts (lhs, rhs_t, k_chunk), and update any references/imports or tests that
pass vocab_chunk to call the new signature; ensure behavior is unchanged by
leaving the internal loop over k_chunk as-is in tiled_lm_head.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@models/qwen3/14b/rms_lm_head.py`:
- Around line 335-339: The function tiled_lm_head declares an unused vocab_chunk
parameter; remove vocab_chunk from the tiled_lm_head signature and from all its
call sites so the function only accepts (lhs, rhs_t, k_chunk), and update any
references/imports or tests that pass vocab_chunk to call the new signature;
ensure behavior is unchanged by leaving the internal loop over k_chunk as-is in
tiled_lm_head.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 5f8efdcc-7132-41a9-9eec-70bcc848719a

📥 Commits

Reviewing files that changed from the base of the PR and between 2d6f086 and 2149695.

📒 Files selected for processing (1)
  • models/qwen3/14b/rms_lm_head.py

- LM_HEAD_K_CHUNK 256 -> 512 (one cube op spans 2 L2 lines, K-loop
  trip 20 -> 10)
- LM_HEAD_OB_CHUNK 4 -> 2 (149 -> 297 dispatches, halves the tail
  bubble on the last core)

Wall-clock on a2a3 b=16: 1524 -> 1447 us (5% extra on top of the
prior 2.5x). Within 5.5% of the 24-core compute floor (~1367 us).
@luohuan19 luohuan19 force-pushed the rms_lm_head_speedup branch from 2149695 to ac5eedc Compare May 26, 2026 13:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant