Speed up Qwen3-14B LM head matmul and add standalone test driver#391
Speed up Qwen3-14B LM head matmul and add standalone test driver#391luohuan19 wants to merge 2 commits into
Conversation
- Override LM_HEAD_K_CHUNK/VOCAB_CHUNK to 256 and add OB_CHUNK=4. Config defaults (64/128) left cube cores at ~45% utilisation behind dispatch bubbles; the wider N+K and ob chunking amortise per-task dispatch and lift the innermost K dim to one L2 line. - Hoist pl.at(CORE_GROUP, lm_head) outside the vocab pl.parallel and pass optimizations=[pl.auto_chunk] so the matmul-accumulate ladder can fuse. - Drop the lm_acc_gm staging tensor and second store scope; trim with pl.set_validshape (metadata only) instead of pl.slice with valid_shape, which forced a rejected acc->acc tmov. - Add test_rms_lm_head jit driver, build_tensor_specs, torch golden, and a __main__ entry so the kernel runs via the standard golden harness.
|
Warning Review limit reached
More reviews will be available in 15 minutes and 27 seconds. Learn how PR review limits work. Your organization has run out of usage credits. Purchase more in the billing tab. ⌛ How to resolve this issue?After more reviews become available, a review can be triggered using the We recommend that you space out your commits to avoid hitting the rate limit. 🚦 How do rate limits work?CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans include higher PR review limits than trial, open-source, and free plans. In all cases, reviews become available again over time. During sustained high-volume PR review activity, CodeRabbit may temporarily slow when the next review becomes available. Please see our Fair Usage Limits Policy for further information. ℹ️ Review info⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughUpdated ChangesRMSNorm LM-head Tiling and Testing
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request optimizes the rms_lm_head kernel by overriding chunk configurations, applying pl.auto_chunk optimizations, and replacing pl.slice with pl.set_validshape to avoid unnecessary data movement. It also adds a JIT test function, tensor spec builders, and a PyTorch golden reference implementation. The review feedback highlights potential issues where dynamic batch sizes smaller than the loop offset could result in negative values for lm_valid_rows and cur_valid, potentially causing compilation or runtime errors. Clamping these values to a minimum of 0 is recommended.
| # forces a tmov to a new acc tile with a different slayout, | ||
| # which ptoas rejects ("expects a supported tmov address-space | ||
| # pair"). set_validshape is metadata-only — no data movement. | ||
| lm_acc_trimmed = pl.set_validshape(lm_acc, lm_valid_rows, VOCAB_CHUNK) |
There was a problem hiding this comment.
If user_batch is dynamic and less than b0, lm_valid_rows will be negative. Passing a negative row count to pl.set_validshape can lead to undefined behavior or compilation failures. Clamp lm_valid_rows to a minimum of 0 to prevent this.
| lm_acc_trimmed = pl.set_validshape(lm_acc, lm_valid_rows, VOCAB_CHUNK) | |
| lm_acc_trimmed = pl.set_validshape(lm_acc, pl.max(0, lm_valid_rows), VOCAB_CHUNK) |
| user_batch = pl.tensor.dim(hidden_states, 0) | ||
| current_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) | ||
| for b0 in pl.parallel(0, BATCH, BATCH_TILE): | ||
| cur_valid = pl.min(BATCH_TILE, user_batch - b0) |
There was a problem hiding this comment.
When user_batch is less than b0 (which can happen when user_batch is dynamic and smaller than BATCH), user_batch - b0 will be negative. This results in a negative cur_valid value being passed to pl.slice as part of valid_shape, which can cause compilation errors or undefined runtime behavior. Clamp cur_valid to a minimum of 0 using pl.max(0, ...) to ensure safety.
| cur_valid = pl.min(BATCH_TILE, user_batch - b0) | |
| cur_valid = pl.max(0, pl.min(BATCH_TILE, user_batch - b0)) |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
models/qwen3/14b/rms_lm_head.py (1)
335-339: 💤 Low valueUnused
vocab_chunkparameter.The
vocab_chunkparameter is defined but never used intiled_lm_head. The function only tiles along the K dimension, not the vocab dimension. This doesn't affect correctness but is dead code.♻️ Remove unused parameter
- def tiled_lm_head(lhs, rhs_t, k_chunk, vocab_chunk): + def tiled_lm_head(lhs, rhs_t, k_chunk): out = torch.zeros(lhs.shape[0], rhs_t.shape[0], dtype=torch.float32) for k0 in range(0, lhs.shape[1], k_chunk): out = out + lhs[:, k0 : k0 + k_chunk].float() @ rhs_t[:, k0 : k0 + k_chunk].float().T return outAnd update the call site at line 345-350:
tensors["out"][:] = tiled_lm_head( final_normed, lm_head_weight, LM_HEAD_K_CHUNK, - VOCAB_CHUNK, )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/qwen3/14b/rms_lm_head.py` around lines 335 - 339, The function tiled_lm_head declares an unused vocab_chunk parameter; remove vocab_chunk from the tiled_lm_head signature and from all its call sites so the function only accepts (lhs, rhs_t, k_chunk), and update any references/imports or tests that pass vocab_chunk to call the new signature; ensure behavior is unchanged by leaving the internal loop over k_chunk as-is in tiled_lm_head.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@models/qwen3/14b/rms_lm_head.py`:
- Around line 335-339: The function tiled_lm_head declares an unused vocab_chunk
parameter; remove vocab_chunk from the tiled_lm_head signature and from all its
call sites so the function only accepts (lhs, rhs_t, k_chunk), and update any
references/imports or tests that pass vocab_chunk to call the new signature;
ensure behavior is unchanged by leaving the internal loop over k_chunk as-is in
tiled_lm_head.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 5f8efdcc-7132-41a9-9eec-70bcc848719a
📒 Files selected for processing (1)
models/qwen3/14b/rms_lm_head.py
- LM_HEAD_K_CHUNK 256 -> 512 (one cube op spans 2 L2 lines, K-loop trip 20 -> 10) - LM_HEAD_OB_CHUNK 4 -> 2 (149 -> 297 dispatches, halves the tail bubble on the last core) Wall-clock on a2a3 b=16: 1524 -> 1447 us (5% extra on top of the prior 2.5x). Within 5.5% of the 24-core compute floor (~1367 us).
2149695 to
ac5eedc
Compare
Summary
rms_lm_headLM-head matmul: overrideLM_HEAD_K_CHUNK=512/VOCAB_CHUNK=256and addLM_HEAD_OB_CHUNK=2to widen cube tasks and amortise dispatch; hoistpl.at(CORE_GROUP, "lm_head")outside the vocabpl.parallelwithoptimizations=[pl.auto_chunk]so the matmul/accumulate ladder can fuse.lm_acc_gmstaging tensor and the second store scope; trim withpl.set_validshape(metadata-only) instead ofpl.slice(..., valid_shape=...)on the accumulator, which forced a tmov that ptoas rejects.test_rms_lm_head@pl.jitdriver withbuild_tensor_specs, a torchgolden_rms_lm_headreference, and a__main__entry so the kernel runs via the standard golden harness (python models/qwen3/14b/rms_lm_head.py -p <platform>).