perf(deepseek-v4): tune prefill sparse attention #394
📝 Walkthrough
Refactors prefill sparse-attn: updates tiling/cache constants and batch default, stages tiled sparse-KV materialization, tiles causal QK/softmax and PV/merge-norm into partial FP32 buffers, stages RoPE selectors, updates golden reference for optional compressed inputs and tiled reduction, and exposes compress-ratio CLI/harness options.
Changes
Prefill Attention Computation and Validation
Estimated Code Review Effort
🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request refactors the prefill sparse attention kernel by optimizing tile sizes, splitting block outputs into separate tensors, and merging the merge and normalization phases. It also updates the golden reference implementation to align with the tiled attention block logic and adjusts command-line arguments. The review feedback suggests adding a compile-time assertion to prevent silent correctness bugs when the number of attention blocks exceeds two, and simplifying a nested ternary expression for better readability.
0 : HEAD_DIM,
]

if PREFILL_ATTN_BLOCKS > 1:
The current implementation of the merge-norm phase only supports up to 2 attention blocks (PREFILL_ATTN_BLOCKS <= 2) because it hardcodes the second block (merge_norm_block_row1) and uses separate tensors prefill_blk_oi0 and prefill_blk_oi1. If S is increased such that PREFILL_ATTN_BLOCKS > 2, any blocks beyond the second will be silently ignored, leading to incorrect attention results.
To prevent silent correctness bugs, we should add a compile-time check/assertion.
Suggested change:
- if PREFILL_ATTN_BLOCKS > 1:
+ if PREFILL_ATTN_BLOCKS > 2:
+     raise ValueError(f"prefill_sparse_attn currently only supports up to 2 attention blocks, but PREFILL_ATTN_BLOCKS={PREFILL_ATTN_BLOCKS}")
+ if PREFILL_ATTN_BLOCKS > 1:
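To see why dropping blocks beyond the second would corrupt the result rather than merely slow it down, here is a minimal pure-Python sketch of merging per-block softmax partials and normalizing once at the end. It assumes scalar values (a head dimension of 1) and ignores the sink bias; `block_partial`, `merge_norm`, and `reference` are hypothetical names for illustration, not functions from `prefill_sparse_attn.py`.

```python
import math


def block_partial(scores, values):
    """Return (max, sum_exp, unnormalized_output) for one tile of scores."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    o = sum(w * v for w, v in zip(weights, values))
    return m, l, o


def merge_norm(partials):
    """Merge any number of per-block partials, then normalize once."""
    m = max(p[0] for p in partials)
    # Rescale every block to the shared maximum before summing.
    l = sum(p[1] * math.exp(p[0] - m) for p in partials)
    o = sum(p[2] * math.exp(p[0] - m) for p in partials)
    return o / l


def reference(scores, values):
    """Untiled softmax-weighted sum, used as the golden value."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)


scores = [0.5, -1.0, 2.0, 0.1, 1.5, -0.3]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
tile = 2  # three blocks, i.e. one more than a two-block merge handles

partials = [
    block_partial(scores[i:i + tile], values[i:i + tile])
    for i in range(0, len(scores), tile)
]
full = merge_norm(partials)           # merges all three blocks
truncated = merge_norm(partials[:2])  # mimics ignoring the third block
```

In this sketch `full` agrees with `reference(scores, values)`, while `truncated` does not, which is the silent failure mode the comment above warns about.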
B_N_CHUNK = 128 if T >= 128 else 256
QUANT_CHUNK = 32 if T >= 128 else (128 if T >= 64 else 256)
QUANT_TOKEN_TILE = 8
QUANT_CHUNK = 128 if T >= 128 else (128 if T >= 64 else 256)
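The review summary mentions simplifying a nested ternary. Assuming that refers to the `QUANT_CHUNK = 128 if T >= 128 else (128 if T >= 64 else 256)` expression, the two branches for `T >= 64` both yield 128, so a single condition should be equivalent. The sketch below checks that over a range of `T`; the function names are hypothetical.

```python
def quant_chunk_nested(t):
    # Expression as written in the diff.
    return 128 if t >= 128 else (128 if t >= 64 else 256)


def quant_chunk_flat(t):
    # Candidate simplification: T >= 128 implies T >= 64.
    return 128 if t >= 64 else 256


mismatches = [
    t for t in range(0, 1024)
    if quant_chunk_nested(t) != quant_chunk_flat(t)
]
```

If `mismatches` is empty, the flat form can replace the nested one without changing behavior for the tested range.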
🧹 Nitpick comments (2)
models/deepseek/v4/prefill_sparse_attn.py (2)
304-305: 💤 Low value
Consider clarifying the zero-tensor creation.
pl.sub(merge_norm_mi, merge_norm_mi) creates a zero tensor with the shape of merge_norm_mi. While correct, the intent isn't immediately obvious. A brief inline comment would help future readers understand this is creating a zero-filled tensor to add merge_norm_sink_bias for broadcasting.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/prefill_sparse_attn.py` around lines 304 - 305, The expression pl.sub(merge_norm_mi, merge_norm_mi) is being used to produce a zero tensor for broadcasting into merge_norm_sink_tile; update the code around the merge_norm_sink_bias and merge_norm_sink_tile lines to clarify intent by either replacing pl.sub(...) with a clearer helper (e.g., pl.zeros_like(merge_norm_mi) if available) or adding a concise inline comment explaining “create zero tensor of merge_norm_mi shape for broadcasting” immediately above/next to the merge_norm_sink_tile assignment (referencing merge_norm_mi, merge_norm_sink_bias, merge_norm_sink_tile and pl.sub).
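As a side note on the `x - x` idiom, the following sketch uses plain Python lists in place of tensors. `sub` and `zeros_like` here are hypothetical stand-ins, not the `pl` API, and whether `merge_norm_mi` can ever hold an infinite value is an assumption not confirmed by the review. The point is only that subtraction yields zeros for finite inputs, while an explicit zero fill does not depend on the input values at all.

```python
import math


def sub(a, b):
    """Elementwise a - b over two equal-length lists."""
    return [x - y for x, y in zip(a, b)]


def zeros_like(a):
    """Zero-filled list with the same length as a."""
    return [0.0 for _ in a]


finite = [1.5, -2.0, 0.0]
finite_diff = sub(finite, finite)      # all zeros, same as zeros_like

with_inf = [1.5, float("-inf")]
inf_diff = sub(with_inf, with_inf)     # second entry is NaN under IEEE 754
```

For finite running maxima the two forms agree, so the nitpick is about readability; the infinite case is shown only to explain why an explicit zero helper can be the safer spelling where one exists.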
245-248: ⚡ Quick win
PV buffering is safe for current S (no PREFILL_ATTN_BLOCKS > 2 path today)
This file hardcodes S = 128 and uses PREFILL_ATTN_TILE = 64, so PREFILL_ATTN_BLOCKS = (S + PREFILL_ATTN_TILE - 1) // PREFILL_ATTN_TILE evaluates to 2; the PV loop only runs pv_sb = 0, 1, so prefill_blk_oi1 isn't overwritten by pv_sb >= 2. Document/guard the S vs PREFILL_ATTN_TILE relationship if these constants are expected to change.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/prefill_sparse_attn.py` around lines 245 - 248, The code assumes PREFILL_ATTN_BLOCKS == 2 so pv_sb only takes 0 and 1, but that makes prefill_blk_oi1 safe only under current constants (S=128, PREFILL_ATTN_TILE=64); add an explicit guard or assertion to prevent silent bugs if S or PREFILL_ATTN_TILE change: validate that PREFILL_ATTN_BLOCKS <= 2 (or that the PV loop range won’t exceed 2) before the PV loop and/or document the invariant near the definitions of S, PREFILL_ATTN_TILE, and PREFILL_ATTN_BLOCKS; reference the names PREFILL_ATTN_TILE, PREFILL_ATTN_BLOCKS, S, pv_sb, prefill_blk_oi0 and prefill_blk_oi1 when adding the check or comment so future maintainers see the dependency.
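The invariant described in this comment could be enforced with a small fail-fast check. This is a sketch under the constants quoted above (`S = 128`, `PREFILL_ATTN_TILE = 64`); `check_attn_blocks` and its `max_blocks` parameter are hypothetical and not part of the module.

```python
def attn_blocks(seq_len, tile):
    """Ceiling division: number of attention tiles covering seq_len."""
    return (seq_len + tile - 1) // tile


def check_attn_blocks(seq_len, tile, max_blocks=2):
    """Raise early if the tiling needs more blocks than the merge handles."""
    blocks = attn_blocks(seq_len, tile)
    if blocks > max_blocks:
        raise ValueError(
            f"tiling needs {blocks} attention blocks "
            f"(S={seq_len}, tile={tile}), but only {max_blocks} are supported"
        )
    return blocks


PREFILL_ATTN_BLOCKS = check_attn_blocks(128, 64)
```

Placing a check like this next to the constant definitions keeps the dependency between `S`, `PREFILL_ATTN_TILE`, and the two `prefill_blk_oi*` buffers visible where the values are changed.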
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 4216acfb-98ab-4a0e-ab4a-b18cc1a57db5
📒 Files selected for processing (1)
models/deepseek/v4/prefill_sparse_attn.py
Actionable comments posted: 1
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: abf8ff5c-d55c-43ef-965c-1278e7a71760
📒 Files selected for processing (1)
models/deepseek/v4/prefill_sparse_attn.py
PREFILL_ATTN_BLOCKS = (PREFILL_SPARSE_TOPK + PREFILL_ATTN_TILE - 1) // PREFILL_ATTN_TILE
PREFILL_SPARSE_PAD = PREFILL_ATTN_BLOCKS * PREFILL_ATTN_TILE
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether the model config guarantees PREFILL_SPARSE_PAD <= TOPK
# Search for FLASH config to understand the sliding_window and index_topk values
echo "=== Searching for FLASH config definition ==="
rg -n "FLASH\s*=" --type py -C 3
echo ""
echo "=== Searching for sliding_window and index_topk in config ==="
rg -n "(sliding_window|index_topk)" --type py -C 2
echo ""
echo "=== Checking if there are any assertions on these constants ==="
rg -n "assert.*PREFILL_SPARSE_PAD|assert.*TOPK" --type py
Repository: hw-native-sys/pypto-lib
Length of output: 13045
Avoid current out-of-bounds for FLASH; add an invariant for future configs
For FLASH (models/deepseek/v4/config.py), sliding_window=128 and index_topk=512, so TOPK=640. With the module’s current constants (PREFILL_MAX_COMPRESSED=32, PREFILL_SPARSE_TOPK=min(640, 128+32)=160, PREFILL_ATTN_TILE=64), PREFILL_SPARSE_PAD=ceil(160/64)*64=192, so iteration up to PREFILL_SPARSE_PAD stays within cmp_sparse_indices’s [T, TOPK] contract.
No assertion currently enforces PREFILL_SPARSE_PAD <= TOPK (the only related check found is models/deepseek/v4/gate.py:38 assert TOPK <= TOPK_PAD), so add an invariant/guard to prevent reintroducing the padded-over-TOPK edge case if S/tiling/padding logic changes.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/deepseek/v4/prefill_sparse_attn.py` around lines 74 - 75, The computed
padding PREFILL_SPARSE_PAD can exceed TOPK if tiling/padding parameters change;
after the PREFILL_SPARSE_TOPK / PREFILL_ATTN_TILE / PREFILL_ATTN_BLOCKS
calculations (symbols: PREFILL_SPARSE_TOPK, PREFILL_ATTN_TILE,
PREFILL_ATTN_BLOCKS, PREFILL_SPARSE_PAD, TOPK) add a guard/assert that
PREFILL_SPARSE_PAD <= TOPK and fail fast with a clear message if violated so
future config changes cannot reintroduce out-of-bounds indexing into
cmp_sparse_indices.
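The arithmetic in this comment can be reproduced and guarded as follows. This is a sketch, not the module's code: it assumes `TOPK` is the sum of `sliding_window` and `index_topk` (which matches the quoted 640) and that the 128 inside `min(640, 128+32)` is the window size; `sparse_pad` is a hypothetical helper.

```python
SLIDING_WINDOW = 128
INDEX_TOPK = 512
TOPK = SLIDING_WINDOW + INDEX_TOPK  # assumption: 128 + 512 = 640
PREFILL_MAX_COMPRESSED = 32
PREFILL_ATTN_TILE = 64


def sparse_pad(topk, window, max_compressed, tile):
    """Return (sparse_topk, blocks, pad), failing fast if pad exceeds topk."""
    sparse_topk = min(topk, window + max_compressed)
    blocks = (sparse_topk + tile - 1) // tile  # ceiling division
    pad = blocks * tile
    if pad > topk:
        raise ValueError(
            f"PREFILL_SPARSE_PAD={pad} exceeds TOPK={topk}; "
            "padded iteration would index past cmp_sparse_indices"
        )
    return sparse_topk, blocks, pad


PREFILL_SPARSE_TOPK, PREFILL_ATTN_BLOCKS, PREFILL_SPARSE_PAD = sparse_pad(
    TOPK, SLIDING_WINDOW, PREFILL_MAX_COMPRESSED, PREFILL_ATTN_TILE
)
```

With the quoted values the check passes, since the padded width of 192 stays below 640. A configuration where `TOPK` were as small as 160 would trip the guard, because the padding to a multiple of 64 would overshoot it.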
Summary
- prefill_sparse_attn.py for the current B=1, S=128 performance bring-up.
- ATTN_TOKEN_TILE=32.
- MERGE_NORM_TOKEN_TILE=16 to reduce long AIV dependency bubbles.
- prefill_rope_slice_tile and intermediate o_proj_even/o_proj_odd scratch tensors.
- QUANT_TOKEN_TILE=32 and QUANT_CHUNK=128.
Notes
Validation
- python3 -m py_compile models/deepseek/v4/prefill_sparse_attn.py
- git diff --check
- attn_out PASS, max_error_ratio=0.005