chore(dsv4): migrate chunked_loop_optimizer to auto_chunk (#388) #389
wangqin1723-max wants to merge 1 commit into
📝 Walkthrough
Four DeepSeek v4 kernel modules migrate loop optimization directives from `chunked_loop_optimizer` to `auto_chunk`.
Changes: Loop Optimization Migration to Auto_Chunk
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request replaces the deprecated optimization=pl.chunked_loop_optimizer argument with optimizations=[pl.auto_chunk] across multiple DeepSeek v4 model files, including decode_attention_hca.py, decode_attention_swa.py, hc_post.py, and qkv_proj_rope.py. The inline documentation and comments are also updated to reflect this API change. No review comments were provided, so there is no feedback to address.
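A rename like this can be applied mechanically. The following is a minimal sketch, not the method used in this PR: a regex rewrite of the keyword argument described above. The helper name `migrate_optimizer_kwarg` is hypothetical; only the two spellings `optimization=pl.chunked_loop_optimizer` and `optimizations=[pl.auto_chunk]` are taken from the review summary.

```python
import re

# Matches the deprecated keyword argument, tolerating spaces around "=".
# It does not match the new plural "optimizations=" spelling.
_OLD_KWARG = re.compile(r"\boptimization\s*=\s*pl\.chunked_loop_optimizer\b")
_NEW_KWARG = "optimizations=[pl.auto_chunk]"


def migrate_optimizer_kwarg(source: str) -> tuple[str, int]:
    """Return the rewritten source and the number of sites changed (hypothetical helper)."""
    return _OLD_KWARG.subn(_NEW_KWARG, source)


old_line = (
    "with pl.at(level=pl.Level.CORE_GROUP, "
    'optimization=pl.chunked_loop_optimizer, name_hint="hca_topk"):'
)
new_line, count = migrate_optimizer_kwarg(old_line)
print(new_line)
print(count)
```

Running the helper a second time on its own output changes nothing, so it is safe to re-run over files that were already migrated.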
00c1e82 to f0f2ad4
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/decode_attention_swa.py`:
- Around line 163-173: The 16-wide parallel loop assumes full 16-lane chunks and
can read/write past valid batch rows when B % 16 != 0; in the
pl.parallel(0, B, 16) block (and the similar block at lines
191-193) replace the fixed pl.range(b0, b0 + 16) with a guarded range or mask by
computing the actual tail width (e.g., tail = min(16, B - b0)) and iterate
pl.range(b0, b0 + tail) or conditionally skip/avoid assemble for b >= B, and
ensure the pl.assemble into kv_cache_flat only happens for valid b (and uses the
same guarded index compute involving block_table_flat, S, s_idx, BLOCK_SIZE,
HEAD_DIM) so no out-of-bounds reads/writes occur.
In `@models/deepseek/v4/hc_post.py`:
- Around line 46-71: The inner loop `for t in pl.range(t0, t0 + 16)` can iterate
past the valid token count when T%16 != 0; modify the iteration to respect the
global T bound by limiting t to < T (e.g., compute t_end = min(t0+16, T) and
iterate pl.range(t0, t_end) or keep the existing range and skip iterations with
an if t >= T guard). Apply this change where the loop over t appears (the block
reading post_flat, x_flat, residual_flat and assembling y_flat) so all uses of t
(e.g., reads from post_flat at [t * HC_MULT + out_h], slices using [t, ...], and
writes to y_flat at [t, ...]) are protected from out-of-bounds access.
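Both findings above come down to the same index arithmetic. As a minimal sketch in plain Python (no `pl` calls; the helper `chunk_rows` and the batch size 20 are illustrative assumptions), this compares a fixed 16-wide inner loop with one whose end is clamped via `min(start + 16, n)`:

```python
def chunk_rows(n: int, width: int = 16, clamp: bool = True) -> list[int]:
    """List every row index visited by a chunked loop over n rows (hypothetical helper)."""
    visited = []
    for start in range(0, n, width):
        # Clamped: the last chunk stops at n. Unclamped: it always spans `width` rows.
        end = min(start + width, n) if clamp else start + width
        visited.extend(range(start, end))
    return visited


B = 20  # hypothetical batch size with B % 16 != 0
unclamped = chunk_rows(B, clamp=False)
clamped = chunk_rows(B, clamp=True)

# Rows the fixed-width loop touches that do not exist in a B-row tensor.
out_of_bounds = [b for b in unclamped if b >= B]
print(len(unclamped), len(clamped), len(out_of_bounds))
```

When the row count is a multiple of 16 the two variants visit identical indices, so the clamp only changes behavior on the tail chunk.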
📒 Files selected for processing (4)
- models/deepseek/v4/decode_attention_hca.py
- models/deepseek/v4/decode_attention_swa.py
- models/deepseek/v4/hc_post.py
- models/deepseek/v4/qkv_proj_rope.py
for t0 in pl.parallel(0, T, 16):
    with pl.at(level=pl.Level.CORE_GROUP, name_hint="hc_post"):
        for t in pl.range(t0, t0 + 16):
            post_w = pl.read(post_flat, [t * HC_MULT + out_h])
            for db in pl.range(D_BLOCKS):
                d0 = db * D_CHUNK
                x_row = pl.cast(
                    pl.slice(x_flat, [1, D_CHUNK], [t, d0]),
                    target_type=pl.FP32,
                )
                y_row = pl.mul(x_row, post_w)
                for in_h in pl.range(HC_MULT):
                    comb_w = pl.read(
                        comb_flat,
                        [t * HC_MULT * HC_MULT + in_h * HC_MULT + out_h],
                    )
                    residual_row = pl.cast(
                        pl.slice(residual_flat, [1, D_CHUNK], [t, in_h * D + d0]),
                        target_type=pl.FP32,
                    )
                    y_row = pl.add(y_row, pl.mul(residual_row, comb_w))
                y_flat = pl.assemble(
                    y_flat,
                    pl.cast(y_row, target_type=pl.BF16, mode="rint"),
                    [t, out_h * D + d0],
                )
Add tail bound check for the new 16-wide t chunk loop.
for t in pl.range(t0, t0 + 16) can run beyond valid token rows on the last chunk when T % 16 != 0.
💡 Suggested fix
 for out_h in pl.parallel(HC_MULT):
     for t0 in pl.parallel(0, T, 16):
         with pl.at(level=pl.Level.CORE_GROUP, name_hint="hc_post"):
             for t in pl.range(t0, t0 + 16):
-                post_w = pl.read(post_flat, [t * HC_MULT + out_h])
-                for db in pl.range(D_BLOCKS):
-                    d0 = db * D_CHUNK
-                    x_row = pl.cast(
-                        pl.slice(x_flat, [1, D_CHUNK], [t, d0]),
-                        target_type=pl.FP32,
-                    )
-                    y_row = pl.mul(x_row, post_w)
-                    for in_h in pl.range(HC_MULT):
-                        comb_w = pl.read(
-                            comb_flat,
-                            [t * HC_MULT * HC_MULT + in_h * HC_MULT + out_h],
-                        )
-                        residual_row = pl.cast(
-                            pl.slice(residual_flat, [1, D_CHUNK], [t, in_h * D + d0]),
-                            target_type=pl.FP32,
-                        )
-                        y_row = pl.add(y_row, pl.mul(residual_row, comb_w))
-                    y_flat = pl.assemble(
-                        y_flat,
-                        pl.cast(y_row, target_type=pl.BF16, mode="rint"),
-                        [t, out_h * D + d0],
-                    )
+                if t < T:
+                    post_w = pl.read(post_flat, [t * HC_MULT + out_h])
+                    for db in pl.range(D_BLOCKS):
+                        d0 = db * D_CHUNK
+                        x_row = pl.cast(
+                            pl.slice(x_flat, [1, D_CHUNK], [t, d0]),
+                            target_type=pl.FP32,
+                        )
+                        y_row = pl.mul(x_row, post_w)
+                        for in_h in pl.range(HC_MULT):
+                            comb_w = pl.read(
+                                comb_flat,
+                                [t * HC_MULT * HC_MULT + in_h * HC_MULT + out_h],
+                            )
+                            residual_row = pl.cast(
+                                pl.slice(residual_flat, [1, D_CHUNK], [t, in_h * D + d0]),
+                                target_type=pl.FP32,
+                            )
+                            y_row = pl.add(y_row, pl.mul(residual_row, comb_w))
+                        y_flat = pl.assemble(
+                            y_flat,
+                            pl.cast(y_row, target_type=pl.BF16, mode="rint"),
+                            [t, out_h * D + d0],
+                        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/deepseek/v4/hc_post.py` around lines 46 - 71, The inner loop `for t in
pl.range(t0, t0 + 16)` can iterate past the valid token count when T%16 != 0;
modify the iteration to respect the global T bound by limiting t to < T (e.g.,
compute t_end = min(t0+16, T) and iterate pl.range(t0, t_end) or keep the
existing range and skip iterations with an if t >= T guard). Apply this change
where the loop over t appears (the block reading post_flat, x_flat,
residual_flat and assembling y_flat) so all uses of t (e.g., reads from
post_flat at [t * HC_MULT + out_h], slices using [t, ...], and writes to y_flat
at [t, ...]) are protected from out-of-bounds access.
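For reference, the loop nest under review can be modeled in plain Python to see what the guard protects. This is a sketch under stated assumptions, not the kernel itself: it uses nested lists instead of `pl` tensors, iterates over `D` element by element rather than in `D_CHUNK` blocks, and leaves out the FP32/BF16 casts. The function name `hc_post_reference` and the sample sizes are hypothetical; the index expressions are copied from the snippet above.

```python
def hc_post_reference(x, residual, post, comb, T, D, HC_MULT, chunk=16):
    """Plain-Python model of the hc_post loop nest with a tail guard on t.

    x:        T rows of D floats
    residual: T rows of HC_MULT * D floats
    post:     flat list, indexed [t * HC_MULT + out_h]
    comb:     flat list, indexed [t * HC_MULT * HC_MULT + in_h * HC_MULT + out_h]
    """
    y = [[0.0] * (HC_MULT * D) for _ in range(T)]
    for out_h in range(HC_MULT):
        for t0 in range(0, T, chunk):
            for t in range(t0, t0 + chunk):
                if t >= T:  # tail guard: skip rows past the valid token count
                    continue
                post_w = post[t * HC_MULT + out_h]
                for d in range(D):
                    acc = x[t][d] * post_w
                    for in_h in range(HC_MULT):
                        comb_w = comb[t * HC_MULT * HC_MULT + in_h * HC_MULT + out_h]
                        acc += residual[t][in_h * D + d] * comb_w
                    y[t][out_h * D + d] = acc
    return y


# Tiny hypothetical problem where T is not a multiple of the chunk width.
T, D, HC_MULT = 3, 2, 2
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
residual = [[1.0] * (HC_MULT * D) for _ in range(T)]
post = [2.0] * (T * HC_MULT)
comb = [1.0] * (T * HC_MULT * HC_MULT)

y = hc_post_reference(x, residual, post, comb, T, D, HC_MULT)
print(y)
```

With `T = 3` the single chunk spans `t = 0..15`; removing the guard in this model raises an `IndexError` at `t = 3`, which is the plain-Python analogue of the out-of-bounds access the review describes.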
  topk_idxs = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32)
  for t0 in pl.range(0, T, HCA_TOPK_CHUNK):
-     with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="hca_topk"):
+     with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="hca_topk"):

  sparse_topk = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32)
  for b0 in pl.range(0, T, SWA_BATCH_CHUNK):
-     with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="swa_topk"):
+     with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="swa_topk"):

  cmp_block_table_dummy = pl.create_tensor([B, SPARSE_CMP_MAX_BLOCKS], dtype=pl.INT32)
  for b0 in pl.range(0, B, SWA_BATCH_CHUNK):
-     with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="swa_cmp_dummy"):
+     with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="swa_cmp_dummy"):
bea63ab to c29b69c
…pk scopes
Replace parallel(chunk=) with explicit parallel+range and migrate the remaining chunked_loop_optimizer sites to auto_chunk; drop auto_chunk from hca_topk/swa_topk/swa_cmp_dummy scopes.
Summary
Closes #388.
`chunked_loop_optimizer` is deprecated upstream; the rest of the repo already uses `auto_chunk`.
Commit 1: swap the 7 remaining sites to `auto_chunk`:
- decode_attention_hca.py: hca_topk
- decode_attention_swa.py: swa_scatter_kv / swa_topk / swa_cmp_dummy
- hc_post.py: hc_post
- qkv_proj_rope.py: attn_norm_rms_partial / qr_rms_partial
Commit 2: for the two `pl.parallel(0,N,1,chunk=16)` sites (hc_post, swa_scatter_kv), drop the optimizer entirely and make chunking explicit: `pl.parallel(0,N,16) + pl.at + pl.range(16)`. The other 5 sites stay on `auto_chunk` (not parallel-chunk loops).
Validation
Run on a2a3 (with `PTO2_RING_*` env), all PASS, precision-neutral:
- decode_hca.py
- decode_csa.py
- decode_attention_swa.py
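The restructuring in Commit 2 can be pictured with ordinary Python ranges. This sketch assumes, without confirmation from the `pl` documentation, that `chunk=16` groups 16 consecutive iterations into one unit of work; `explicit_chunks` is a hypothetical helper that models the outer stride-16 loop with a fixed 16-row inner loop.

```python
def explicit_chunks(n: int, width: int = 16) -> list[list[int]]:
    """Group row indices the way an outer stride-`width` loop with a fixed inner loop would."""
    chunks = []
    for start in range(0, n, width):  # models pl.parallel(0, N, 16)
        # Models the inner 16-iteration pl.range inside one pl.at scope.
        chunks.append(list(range(start, start + width)))
    return chunks


N = 64  # hypothetical size that is a multiple of 16
chunks = explicit_chunks(N)
flattened = [i for chunk in chunks for i in chunk]

print(len(chunks))                   # number of scopes entered
print(flattened == list(range(N)))   # same rows as the flat loop over range(0, N, 1)
```

Under that assumption the explicit form covers exactly the rows of the flat loop whenever `N` is a multiple of 16, entering `N // 16` scopes; for other values of `N` the last chunk extends past `N`, which is the case the review comments flag.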