Add DeepSeek V4 prefill QKV RoPE tile#374
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds a token-chunked DeepSeek-V4 prefill Q/KV projection kernel with partial RoPE tiling, introduces prefill tuning constants, implements a JIT core (RMSNorm, LoRA, per-row INT8 quantization, partial RoPE, INT8 matmul/dequant, BF16 outputs), and provides a Torch golden reference, tensor-spec builder, and CLI validator. ChangesDeepSeek-V4 Prefill Q/KV Projection Kernel
Sequence DiagramsequenceDiagram
participant Activations as Activations (input)
participant RMSNorm as RMSNorm + Gamma
participant LoRA_A as LoRA A
participant Quant as INT8 Quantize (qr -> int8 + scale)
participant KV_RoPE as Partial RoPE (KV tail)
participant INT8MatMul as W8A8C16 MatMul (qr, wq_b)
participant RoPE_Q as Partial RoPE (Q tail)
participant Outputs as Outputs (q, kv, qr, qr_scale)
Activations->>RMSNorm: chunked flatten & normalize
RMSNorm->>LoRA_A: produce LoRA A query projection
LoRA_A->>Quant: normalize & quantize qr (int8 + scale)
Activations->>KV_RoPE: KV path RMSNorm + partial RoPE
Quant->>INT8MatMul: qr_tile (int8) + qr_scale
INT8MatMul->>RoPE_Q: dequantize & apply per-head RMSNorm, partial RoPE
RoPE_Q->>Outputs: assemble interleaved heads, write BF16 q/kv and qr/qr_scale
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
models/deepseek/v4/prefill_qkv_proj_rope.py (1)
408-408: 💤 Low valueInconsistent use of
pl.rsqrtvspl.recip(pl.sqrt(...)).This line uses
pl.rsqrt(...)while all other inverse-RMS computations in this file (lines 147, 209, 288) and in the decode kernel usepl.recip(pl.sqrt(...)). While mathematically equivalent, different instruction paths could cause subtle numerical differences that may affect validation consistency.Suggested fix for consistency
- q_head_inv_rms = pl.rsqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS)) + q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS)))🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/prefill_qkv_proj_rope.py` at line 408, The line computing q_head_inv_rms uses pl.rsqrt(...) which is inconsistent with other inverse-RMS calculations; replace the pl.rsqrt(...) expression in the q_head_inv_rms assignment with the equivalent pl.recip(pl.sqrt(...)) form (matching the pattern used elsewhere) so the computation for q_head_inv_rms uses pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) and thus aligns numerically with the other inverse-RMS uses.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@models/deepseek/v4/prefill_qkv_proj_rope.py`:
- Line 408: The line computing q_head_inv_rms uses pl.rsqrt(...) which is
inconsistent with other inverse-RMS calculations; replace the pl.rsqrt(...)
expression in the q_head_inv_rms assignment with the equivalent
pl.recip(pl.sqrt(...)) form (matching the pattern used elsewhere) so the
computation for q_head_inv_rms uses
pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) and thus
aligns numerically with the other inverse-RMS uses.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 023d4d38-d85a-4832-9a4c-4b961702d334
📒 Files selected for processing (2)
models/deepseek/v4/config.pymodels/deepseek/v4/prefill_qkv_proj_rope.py
There was a problem hiding this comment.
Code Review
This pull request introduces a new JIT-compiled kernel for DeepSeek-V4 prefill Q/KV projection and partial RoPE application, along with corresponding deployment configuration constants. The implementation covers RMSNorm, LoRA projections, and W8A8C16 quantization. Review feedback identifies opportunities to optimize the kernel by moving loop-invariant frequency slicing outside the batch tile loop and utilizing the pl.rsqrt primitive for more efficient RMSNorm inverse calculations across the query and KV paths.
There was a problem hiding this comment.
♻️ Duplicate comments (1)
models/deepseek/v4/prefill_qkv_proj_rope.py (1)
402-405:⚠️ Potential issue | 🟠 Major | ⚡ Quick winRestore
pl.rsqrt(...)on the q-head RMS path.Line 405 regresses the per-head q normalization to
pl.recip(pl.sqrt(...)), but the decode kernel and the golden q path both usersqrthere. That rounding delta is exactly on the post-dequant q path we compare, so this can drift from the validated decode behavior.Proposed fix
- q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) + q_head_inv_rms = pl.rsqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))Run this to verify the mismatch against the existing decode path and golden reference:
#!/bin/bash set -euo pipefail rg -n -C2 'q_head_inv_rms\s*=' models/deepseek/v4/prefill_qkv_proj_rope.py models/deepseek/v4/qkv_proj_rope.py rg -n -C2 'torch\.rsqrt|pl\.rsqrt|pl\.recip\(pl\.sqrt' models/deepseek/v4/prefill_qkv_proj_rope.py models/deepseek/v4/qkv_proj_rope.pyExpected result: the validated decode q path and golden reference show
rsqrt, while this prefill q-head RMS path showspl.recip(pl.sqrt(...)).🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/prefill_qkv_proj_rope.py` around lines 402 - 405, The q-head RMS calculation currently uses pl.recip(pl.sqrt(...)) which causes rounding differences; change the computation that assigns q_head_inv_rms to use pl.rsqrt(...) instead (i.e., call pl.rsqrt on the same argument pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS)), so the prefill q normalization matches the decode/golden q path (refer to the q_head_inv_rms variable, HEAD_DIM, EPS, and pl.rsqrt).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@models/deepseek/v4/prefill_qkv_proj_rope.py`:
- Around line 402-405: The q-head RMS calculation currently uses
pl.recip(pl.sqrt(...)) which causes rounding differences; change the computation
that assigns q_head_inv_rms to use pl.rsqrt(...) instead (i.e., call pl.rsqrt on
the same argument pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS)), so the
prefill q normalization matches the decode/golden q path (refer to the
q_head_inv_rms variable, HEAD_DIM, EPS, and pl.rsqrt).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: cc60a769-e828-47b3-97d7-6edc67ee8079
📒 Files selected for processing (2)
models/deepseek/v4/config.pymodels/deepseek/v4/prefill_qkv_proj_rope.py
cf2c52c to
6c5c7e1
Compare
Summary
PREFILL_BATCH=1andPREFILL_SEQ=128for the current QKV RoPE kernel invocation, while keeping internal token chunking for projection, quantization, and RoPE scopes.Tassumptions; reuse only the stable QKV tensor-spec helper for golden inputs.start_posand align the q-path RMS inverse math with the existing decode QKV core.Validation
python3 -m py_compile models/deepseek/v4/config.py models/deepseek/v4/prefill_qkv_proj_rope.pypython models/deepseek/v4/prefill_qkv_proj_rope.pypython models/deepseek/v4/prefill_qkv_proj_rope.py --start-pos 128Both remote NPU runs passed for
q,kv,qr, andqr_scale.