Skip to content

perf(deepseek-v4): tune prefill sparse attention#394

Open
sjduan wants to merge 3 commits into
hw-native-sys:mainfrom
sjduan:perf-prefill-sparse-attn
Open

perf(deepseek-v4): tune prefill sparse attention#394
sjduan wants to merge 3 commits into
hw-native-sys:mainfrom
sjduan:perf-prefill-sparse-attn

Conversation

@sjduan
Copy link
Copy Markdown
Contributor

@sjduan sjduan commented May 27, 2026

Summary

  • Tune DeepSeek V4 prefill_sparse_attn.py for the current B=1, S=128 performance bring-up.
  • Coarsen Stage2 attention task granularity with ATTN_TOKEN_TILE=32.
  • Split merge/norm into MERGE_NORM_TOKEN_TILE=16 to reduce long AIV dependency bubbles.
  • Fuse Stage3 RoPE slice/apply by removing prefill_rope_slice_tile and intermediate o_proj_even/o_proj_odd scratch tensors.
  • Increase Stage6 quant granularity with QUANT_TOKEN_TILE=32 and QUANT_CHUNK=128.

Notes

  • This keeps the softmax and PV scopes separated; the attempted fused softmax+PV variant was correct but slower in testing.
  • The Stage3 selector copy is intentional to avoid the transpose/non-transpose lowering conflict when using selector tensors in the same scope.

Validation

  • python3 -m py_compile models/deepseek/v4/prefill_sparse_attn.py
  • git diff --check
  • Remote A2A3 validation passed with L2 swimlane:
    • result: attn_out PASS, max_error_ratio=0.005

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 27, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Refactors prefill sparse-attn: updates tiling/cache constants and batch default, stages tiled sparse-KV materialization, tiles causal QK/softmax and PV/merge-norm into partial FP32 buffers, stages RoPE selectors, updates golden reference for optional compressed inputs and tiled reduction, and exposes compress-ratio CLI/harness options.

Changes

Prefill Attention Computation and Validation

Layer / File(s) Summary
Tiling / cache configuration and constants
models/deepseek/v4/prefill_sparse_attn.py
Adjusts standalone batch default (B → 1), tiling constants (ATTN_TOKEN_TILE, MERGE_NORM_TOKEN_TILE), INT8 quant tiling defaults, and compressed-cache sizing (SUPPORTED_COMPRESS_RATIOS, DEFAULT_COMPRESS_RATIO, conditional CMP_MAX_BLOCKS).
Kernel tensor allocation and staged selectors
models/deepseek/v4/prefill_sparse_attn.py
Adds sparse_kv, splits PV partial outputs into multiple FP32 buffers (prefill_blk_oi0/oi1/oi2), and introduces even_select_stage/odd_select_stage selector buffers.
Stage 1: Tiled sparse KV materialization
models/deepseek/v4/prefill_sparse_attn.py
Materializes a tiled sparse_kv from cmp_sparse_indices with negative-index zero-fill, <S gather from ori_kv, >=S gather from cmp_kv via cmp_block_table, and zero-padding beyond valid lengths.
Stage 2: Causal QK and softmax tiling
models/deepseek/v4/prefill_sparse_attn.py
Computes causal QK and softmax by iterating PREFILL_ATTN_BLOCKS, deriving per-tile validity from cmp_sparse_indices, slicing/padding tile scores accordingly, and reorganizing softmax intermediates.
PV, partial outputs and merge/norm
models/deepseek/v4/prefill_sparse_attn.py
PV uses prefill_exp with sparse_kv tiles, writes partial oi into up to three FP32 buffers per block, merges mi/li/oi across blocks via log-sum-exp style rescaling, and assembles final BF16 attention rows.
Inverse RoPE with staged selectors
models/deepseek/v4/prefill_sparse_attn.py
Adds explicit copy of provided *_select_local into even_select_stage/odd_select_stage and applies these staged selectors in inverse RoPE projection multiplication.
Golden reference and tiled causal ref
models/deepseek/v4/prefill_sparse_attn.py
Adds get_prefill_cmp_valid, makes golden_prefill_sparse_attn accept optional compressed inputs, builds KV via tiled sparse gathering, and computes causal attention in PREFILL_ATTN_TILE tiles with incremental mi/li/oi updates.
Tensor specs, sparse-index init, and harness wiring
models/deepseek/v4/prefill_sparse_attn.py
Parameterizes build_tensor_specs/runner by compress_ratio, updates init_cmp_sparse_indices to append compressed indices based on compress_ratio, updates harness runner to pass compress_ratio, and prints it in golden output.

Possibly Related PRs

  • hw-native-sys/pypto-lib#256: Related change that staged/wired even_select_local/odd_select_local through HCA/SWA into sparse_attn, matching selector staging in this PR.
  • hw-native-sys/pypto-lib#365: Prior refactor touching the same prefill_sparse_attn.py kernel and golden harness behavior.

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 Tiles stack soft and neat,
Selectors copied, buffers meet,
Partial sums in FP32,
Log-sum-exp brings unity,
Sparse attention hops to beat!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 40.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately summarizes the main change: performance tuning of DeepSeek-V4 prefill sparse attention, which matches the primary focus of the changeset.
Description check ✅ Passed The description is directly related to the changeset, detailing specific tuning parameters, optimizations, and validation performed, which aligns with the changes made.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the prefill sparse attention kernel by optimizing tile sizes, splitting block outputs into separate tensors, and merging the merge and normalization phases. It also updates the golden reference implementation to align with the tiled attention block logic and adjusts command-line arguments. The review feedback suggests adding a compile-time assertion to prevent silent correctness bugs when the number of attention blocks exceeds two, and simplifying a nested ternary expression for better readability.

0 : HEAD_DIM,
]

if PREFILL_ATTN_BLOCKS > 1:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of the merge-norm phase only supports up to 2 attention blocks (PREFILL_ATTN_BLOCKS <= 2) because it hardcodes the second block (merge_norm_block_row1) and uses separate tensors prefill_blk_oi0 and prefill_blk_oi1. If S is increased such that PREFILL_ATTN_BLOCKS > 2, any blocks beyond the second will be silently ignored, leading to incorrect attention results.

To prevent silent correctness bugs, we should add a compile-time check/assertion.

Suggested change
if PREFILL_ATTN_BLOCKS > 1:
if PREFILL_ATTN_BLOCKS > 2:
raise ValueError(f"prefill_sparse_attn currently only supports up to 2 attention blocks, but PREFILL_ATTN_BLOCKS={PREFILL_ATTN_BLOCKS}")
if PREFILL_ATTN_BLOCKS > 1:

B_N_CHUNK = 128 if T >= 128 else 256
QUANT_CHUNK = 32 if T >= 128 else (128 if T >= 64 else 256)
QUANT_TOKEN_TILE = 8
QUANT_CHUNK = 128 if T >= 128 else (128 if T >= 64 else 256)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This nested ternary expression can be simplified for better readability. Since both the T >= 128 and T >= 64 conditions result in 128, they can be combined into a single condition.

Suggested change
QUANT_CHUNK = 128 if T >= 128 else (128 if T >= 64 else 256)
QUANT_CHUNK = 128 if T >= 64 else 256

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
models/deepseek/v4/prefill_sparse_attn.py (2)

304-305: 💤 Low value

Consider clarifying the zero-tensor creation.

pl.sub(merge_norm_mi, merge_norm_mi) creates a zero tensor with the shape of merge_norm_mi. While correct, the intent isn't immediately obvious. A brief inline comment would help future readers understand this is creating a zero-filled tensor to add merge_norm_sink_bias for broadcasting.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/prefill_sparse_attn.py` around lines 304 - 305, The
expression pl.sub(merge_norm_mi, merge_norm_mi) is being used to produce a zero
tensor for broadcasting into merge_norm_sink_tile; update the code around the
merge_norm_sink_bias and merge_norm_sink_tile lines to clarify intent by either
replacing pl.sub(...) with a clearer helper (e.g., pl.zeros_like(merge_norm_mi)
if available) or adding a concise inline comment explaining “create zero tensor
of merge_norm_mi shape for broadcasting” immediately above/next to the
merge_norm_sink_tile assignment (referencing merge_norm_mi,
merge_norm_sink_bias, merge_norm_sink_tile and pl.sub).

245-248: ⚡ Quick win

PV buffering is safe for current S (no PREFILL_ATTN_BLOCKS > 2 path today)
This file hardcodes S = 128 and uses PREFILL_ATTN_TILE = 64, so PREFILL_ATTN_BLOCKS = (S + PREFILL_ATTN_TILE - 1) // PREFILL_ATTN_TILE evaluates to 2; the PV loop only runs pv_sb = 0, 1, so prefill_blk_oi1 isn’t overwritten by pv_sb >= 2. Document/guard the S vs PREFILL_ATTN_TILE relationship if these constants are expected to change.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/prefill_sparse_attn.py` around lines 245 - 248, The code
assumes PREFILL_ATTN_BLOCKS == 2 so pv_sb only takes 0 and 1, but that makes
prefill_blk_oi1 safe only under current constants (S=128, PREFILL_ATTN_TILE=64);
add an explicit guard or assertion to prevent silent bugs if S or
PREFILL_ATTN_TILE change: validate that PREFILL_ATTN_BLOCKS <= 2 (or that the PV
loop range won’t exceed 2) before the PV loop and/or document the invariant near
the definitions of S, PREFILL_ATTN_TILE, and PREFILL_ATTN_BLOCKS; reference the
names PREFILL_ATTN_TILE, PREFILL_ATTN_BLOCKS, S, pv_sb, prefill_blk_oi0 and
prefill_blk_oi1 when adding the check or comment so future maintainers see the
dependency.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@models/deepseek/v4/prefill_sparse_attn.py`:
- Around line 304-305: The expression pl.sub(merge_norm_mi, merge_norm_mi) is
being used to produce a zero tensor for broadcasting into merge_norm_sink_tile;
update the code around the merge_norm_sink_bias and merge_norm_sink_tile lines
to clarify intent by either replacing pl.sub(...) with a clearer helper (e.g.,
pl.zeros_like(merge_norm_mi) if available) or adding a concise inline comment
explaining “create zero tensor of merge_norm_mi shape for broadcasting”
immediately above/next to the merge_norm_sink_tile assignment (referencing
merge_norm_mi, merge_norm_sink_bias, merge_norm_sink_tile and pl.sub).
- Around line 245-248: The code assumes PREFILL_ATTN_BLOCKS == 2 so pv_sb only
takes 0 and 1, but that makes prefill_blk_oi1 safe only under current constants
(S=128, PREFILL_ATTN_TILE=64); add an explicit guard or assertion to prevent
silent bugs if S or PREFILL_ATTN_TILE change: validate that PREFILL_ATTN_BLOCKS
<= 2 (or that the PV loop range won’t exceed 2) before the PV loop and/or
document the invariant near the definitions of S, PREFILL_ATTN_TILE, and
PREFILL_ATTN_BLOCKS; reference the names PREFILL_ATTN_TILE, PREFILL_ATTN_BLOCKS,
S, pv_sb, prefill_blk_oi0 and prefill_blk_oi1 when adding the check or comment
so future maintainers see the dependency.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 4216acfb-98ab-4a0e-ab4a-b18cc1a57db5

📥 Commits

Reviewing files that changed from the base of the PR and between e84fb6f and 0051e7e.

📒 Files selected for processing (1)
  • models/deepseek/v4/prefill_sparse_attn.py

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/deepseek/v4/prefill_sparse_attn.py`:
- Around line 74-75: The computed padding PREFILL_SPARSE_PAD can exceed TOPK if
tiling/padding parameters change; after the PREFILL_SPARSE_TOPK /
PREFILL_ATTN_TILE / PREFILL_ATTN_BLOCKS calculations (symbols:
PREFILL_SPARSE_TOPK, PREFILL_ATTN_TILE, PREFILL_ATTN_BLOCKS, PREFILL_SPARSE_PAD,
TOPK) add a guard/assert that PREFILL_SPARSE_PAD <= TOPK and fail fast with a
clear message if violated so future config changes cannot reintroduce
out-of-bounds indexing into cmp_sparse_indices.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: abf8ff5c-d55c-43ef-965c-1278e7a71760

📥 Commits

Reviewing files that changed from the base of the PR and between 0051e7e and 5e25e1c.

📒 Files selected for processing (1)
  • models/deepseek/v4/prefill_sparse_attn.py

Comment on lines +74 to +75
PREFILL_ATTN_BLOCKS = (PREFILL_SPARSE_TOPK + PREFILL_ATTN_TILE - 1) // PREFILL_ATTN_TILE
PREFILL_SPARSE_PAD = PREFILL_ATTN_BLOCKS * PREFILL_ATTN_TILE
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether the model config guarantees PREFILL_SPARSE_PAD <= TOPK

# Search for FLASH config to understand the sliding_window and index_topk values
echo "=== Searching for FLASH config definition ==="
rg -n "FLASH\s*=" --type py -C 3

echo ""
echo "=== Searching for sliding_window and index_topk in config ==="
rg -n "(sliding_window|index_topk)" --type py -C 2

echo ""
echo "=== Checking if there are any assertions on these constants ==="
rg -n "assert.*PREFILL_SPARSE_PAD|assert.*TOPK" --type py

Repository: hw-native-sys/pypto-lib

Length of output: 13045


Avoid current out-of-bounds for FLASH; add an invariant for future configs

For FLASH (models/deepseek/v4/config.py), sliding_window=128 and index_topk=512, so TOPK=640. With the module’s current constants (PREFILL_MAX_COMPRESSED=32, PREFILL_SPARSE_TOPK=min(640, 128+32)=160, PREFILL_ATTN_TILE=64), PREFILL_SPARSE_PAD=ceil(160/64)*64=192, so iteration up to PREFILL_SPARSE_PAD stays within cmp_sparse_indices’s [T, TOPK] contract.

No assertion currently enforces PREFILL_SPARSE_PAD <= TOPK (the only related check found is models/deepseek/v4/gate.py:38 assert TOPK <= TOPK_PAD), so add an invariant/guard to prevent reintroducing the padded-over-TOPK edge case if S/tiling/padding logic changes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/prefill_sparse_attn.py` around lines 74 - 75, The computed
padding PREFILL_SPARSE_PAD can exceed TOPK if tiling/padding parameters change;
after the PREFILL_SPARSE_TOPK / PREFILL_ATTN_TILE / PREFILL_ATTN_BLOCKS
calculations (symbols: PREFILL_SPARSE_TOPK, PREFILL_ATTN_TILE,
PREFILL_ATTN_BLOCKS, PREFILL_SPARSE_PAD, TOPK) add a guard/assert that
PREFILL_SPARSE_PAD <= TOPK and fail fast with a clear message if violated so
future config changes cannot reintroduce out-of-bounds indexing into
cmp_sparse_indices.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant