From 4ca326029793f25e3d9118e0ab7bd48556ba8479 Mon Sep 17 00:00:00 2001
From: zhangqi-chen
Date: Tue, 26 May 2026 13:04:19 +0800
Subject: [PATCH 1/8] dsv4: refactor qkv_proj_rope with scope fusion, T-tile parallelism, and constant inlining

- Fuse attn_norm_rms + attn_norm_apply, qr_rms_norm + qr_norm_apply, qr_quant_amax + qr_quant_apply, qproj_matmul + qproj_dequant, kv_rms + kv_norm_nope scopes
- T-tile parallelism for attn_norm and qr_rms_norm (per-token reduce)
- qproj merged scope splits dequant on T_TILE to keep cube+vec under Vec UB
- Halve task counts for qr_proj_matmul and qproj
- Inline *_BLOCKS, _GROUP, and stale conditional CHUNK expressions
- Rename *_CHUNK constants to *_TILE for consistency
- Drop chunked_loop_optimizer and partial-sum scaffolding (Opt S/U)
---
 models/deepseek/v4/qkv_proj_rope.py | 371 +++++++++++----------------
 1 file changed, 147 insertions(+), 224 deletions(-)

diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py
index a871ae31..cc2821fd 100644
--- a/models/deepseek/v4/qkv_proj_rope.py
+++ b/models/deepseek/v4/qkv_proj_rope.py
@@ -32,48 +32,18 @@ # Group constants control pl.parallel(0, N, GROUP) + pl.range(GROUP) folding — # how many logical chunks are fused into one InCore task. See Opt J/K/L/N/O/P # in docs/dsv4-qkv-proj-rope-perf-tuning.md for the per-scope sweep results.
-ROPE_CHUNK = 64 -ROPE_PAIR_CHUNK = ROPE_CHUNK // 2 -HEAD_CHUNK = 64 -HEAD_GROUP = 8 -Q_PROJ_OUT_CHUNK = 128 -Q_PROJ_CHUNK = 512 # K-tile; doubled from 256 (Opt V) since cube was K-bound on qproj_matmul -Q_PROJ_GROUP = 8 # N-tile head-blocks fused into one qproj_matmul task -QR_NORM_GROUP = 8 # Q_LORA_CHUNK blocks fused into one qr_norm_apply task -ATTN_NORM_GROUP = 4 # D_CHUNK blocks fused into one attn_norm_apply task -KV_PROJ_GROUP = 1 # KV_CHUNK blocks fused into one kv_proj_matmul task -Q_PROJ_DEQUANT_GROUP = 32 # qproj_dequant decoupled from qproj_matmul with its own larger group -ATTN_RMS_PARTIALS = 2 # parallel workers for attn_norm_rms (Opt S); 2-way keeps FP32 reduce deterministic -QR_RMS_PARTIALS = 2 # parallel workers for qr_rms (Opt U); same precision argument as ATTN_RMS_PARTIALS +ROPE_TILE = 64 +ROPE_PAIR_TILE = ROPE_TILE // 2 +HEAD_TILE = 64 +Q_PROJ_OUT_TILE = 128 +Q_PROJ_TILE = 512 # K-tile Q_LORA_TILE = 32 -Q_LORA_CHUNK = Q_LORA_TILE -D_CHUNK = 128 if T >= 128 else (256 if T >= 64 else 512) -KV_CHUNK = 32 -QUANT_CHUNK = 32 if T >= 128 else (128 if T >= 64 else 256) -QUANT_APPLY_CHUNK = 256 -assert (H * HEAD_DIM) % (HEAD_CHUNK * HEAD_GROUP) == 0, \ - "HEAD_BLOCKS must be divisible by HEAD_GROUP" -assert ((H * HEAD_DIM) // Q_PROJ_OUT_CHUNK) % Q_PROJ_GROUP == 0, \ - "Q_PROJ_HEAD_BLOCKS must be divisible by Q_PROJ_GROUP" -assert (Q_LORA // Q_LORA_TILE) % QR_NORM_GROUP == 0, \ - "Q_BLOCKS must be divisible by QR_NORM_GROUP" -assert (D // D_CHUNK) % ATTN_NORM_GROUP == 0, \ - "D_BLOCKS must be divisible by ATTN_NORM_GROUP" -assert (HEAD_DIM // KV_CHUNK) % KV_PROJ_GROUP == 0, \ - "KV_BLOCKS must be divisible by KV_PROJ_GROUP" -assert ((H * HEAD_DIM) // Q_PROJ_OUT_CHUNK) % Q_PROJ_DEQUANT_GROUP == 0, \ - "Q_PROJ_HEAD_BLOCKS must be divisible by Q_PROJ_DEQUANT_GROUP" -assert (D // D_CHUNK) % ATTN_RMS_PARTIALS == 0, \ - "D_BLOCKS must be divisible by ATTN_RMS_PARTIALS" -assert (Q_LORA // Q_LORA_TILE) % QR_RMS_PARTIALS == 0, \ - "Q_BLOCKS must be divisible by 
QR_RMS_PARTIALS" -Q_BLOCKS = Q_LORA // Q_LORA_TILE -Q_PROJ_BLOCKS = Q_LORA // Q_PROJ_CHUNK -HEAD_BLOCKS = (H * HEAD_DIM) // HEAD_CHUNK -Q_PROJ_HEAD_BLOCKS = (H * HEAD_DIM) // Q_PROJ_OUT_CHUNK -HEAD_GROUP_BLOCKS = (H * HEAD_DIM) // (HEAD_CHUNK * HEAD_GROUP) -D_BLOCKS = D // D_CHUNK -KV_BLOCKS = HEAD_DIM // KV_CHUNK +D_TILE = 128 +KV_TILE = 32 +QUANT_TILE = 32 +T_TILE = 16 # T-axis sub-tile for qproj dequant (keeps cube+vec fused scope under Vec UB) +assert (H * HEAD_DIM) % (HEAD_TILE * 8) == 0, \ + "HEAD_BLOCKS must be divisible by 8" @pl.jit.inline @@ -82,7 +52,7 @@ def qkv_proj_rope( norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[(H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], rope_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], rope_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], @@ -97,131 +67,86 @@ def qkv_proj_rope( ): x_flat = pl.reshape(x, [T, D]) - # Stage 0.1: attn_norm RMS — parallel partial sum (Opt S). - # Single-task serial reduce was ~93us at S=2; split into ATTN_RMS_PARTIALS - # workers + a small final reduce. auto_chunk is REQUIRED here: - # without it the inner pl.range tile allocations accumulate and exceed the - # 192KB Vec UB at S=2/T=128 (verified by compile failure during tuning). - # PARTIALS=2 (not 4+) keeps the FP32 add associativity-free, preserving `q` - # validation across devices. 
- D_BLOCKS_PER_PARTIAL = D_BLOCKS // ATTN_RMS_PARTIALS - x_sq_partial = pl.create_tensor([ATTN_RMS_PARTIALS, T], dtype=pl.FP32) - for wg in pl.parallel(0, ATTN_RMS_PARTIALS, 1): - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="attn_norm_rms_partial"): - rms_d_base = wg * D_BLOCKS_PER_PARTIAL * D_CHUNK - local_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for rms_db in pl.range(D_BLOCKS_PER_PARTIAL): - rms_d0 = rms_d_base + rms_db * D_CHUNK - rms_x_chunk = pl.cast(x_flat[:, rms_d0 : rms_d0 + D_CHUNK], target_type=pl.FP32) - local_sum = pl.add(local_sum, pl.reshape(pl.row_sum(pl.mul(rms_x_chunk, rms_x_chunk)), [1, T])) - x_sq_partial[wg : wg + 1, :] = local_sum - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="attn_norm_rms_final"): - x_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for w in pl.range(ATTN_RMS_PARTIALS): - x_sq_sum = pl.add(x_sq_sum, x_sq_partial[w : w + 1, :]) - x_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(x_sq_sum, 1.0 / D), EPS))) - - # Stage 0.2: fused norm + FP32->BF16 cast (Opt E folded token_x_cast_bf16 in; - # the intermediate `token_x_fp32` GM buffer is gone). ATTN_NORM_GROUP-chunked - # (Opt N) — token_x_bf16 is the only cross-iter loop-carried tensor. - x_inv_rms_t = pl.reshape(x_inv_rms, [T, 1]) + # Stage 0.1+0.2: attn_norm fused RMS + apply, T-tiled parallel (per-token reduction). 
token_x_bf16 = pl.create_tensor([T, D], dtype=pl.BF16) - for dbg in pl.parallel(0, D_BLOCKS, ATTN_NORM_GROUP): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="attn_norm_apply"): - for d_inner in pl.range(ATTN_NORM_GROUP): - apply_d0 = (dbg + d_inner) * D_CHUNK - apply_x_chunk = pl.cast(x_flat[:, apply_d0 : apply_d0 + D_CHUNK], target_type=pl.FP32) - norm_w_chunk = pl.reshape(norm_w[apply_d0 : apply_d0 + D_CHUNK], [1, D_CHUNK]) + for tg in pl.parallel(0, T, T_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="attn_norm"): + x_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) + for rms_db in pl.range(D // D_TILE): + rms_d0 = rms_db * D_TILE + rms_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, rms_d0 : rms_d0 + D_TILE], target_type=pl.FP32) + x_sq_sum = pl.add(x_sq_sum, pl.reshape(pl.row_sum(pl.mul(rms_x_chunk, rms_x_chunk)), [1, T_TILE])) + x_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(x_sq_sum, 1.0 / D), EPS))) + x_inv_rms_t = pl.reshape(x_inv_rms, [T_TILE, 1]) + for apply_db in pl.range(D // D_TILE): + apply_d0 = apply_db * D_TILE + apply_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, apply_d0 : apply_d0 + D_TILE], target_type=pl.FP32) + norm_w_chunk = pl.reshape(norm_w[apply_d0 : apply_d0 + D_TILE], [1, D_TILE]) x_normed = pl.col_expand_mul(pl.row_expand_mul(apply_x_chunk, x_inv_rms_t), norm_w_chunk) - token_x_bf16[:, apply_d0 : apply_d0 + D_CHUNK] = pl.cast(x_normed, target_type=pl.BF16, mode="rint") + token_x_bf16[tg : tg + T_TILE, apply_d0 : apply_d0 + D_TILE] = pl.cast(x_normed, target_type=pl.BF16, mode="rint") # Stage 1/2.1: qr = rms_norm(token_x @ wq_a, gamma_cq). # K loop uses pl.pipeline(stage=4) for 4-deep ping-pong on the D=4096 input # projection (D_BLOCKS=32, sufficient iter count for 4-stage replication). 
qr_fp32 = pl.create_tensor([T, Q_LORA], dtype=pl.FP32) - for qb in pl.parallel(0, Q_BLOCKS, 1): + for qbg in pl.parallel(0, Q_LORA // Q_LORA_TILE, 2): with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_proj_matmul"): - q_a_col0 = qb * Q_LORA_CHUNK - q_acc = pl.create_tensor([T, Q_LORA_CHUNK], dtype=pl.FP32) - for db in pl.pipeline(0, D_BLOCKS, stage=4): - qr_d0 = db * D_CHUNK - q_x_chunk_bf16 = token_x_bf16[:, qr_d0 : qr_d0 + D_CHUNK] - w_chunk = wq_a[qr_d0 : qr_d0 + D_CHUNK, q_a_col0 : q_a_col0 + Q_LORA_CHUNK] - if qr_d0 == 0: - q_acc = pl.matmul(q_x_chunk_bf16, w_chunk, out_dtype=pl.FP32) - else: - q_acc = pl.matmul_acc(q_acc, q_x_chunk_bf16, w_chunk) - qr_fp32[:, q_a_col0 : q_a_col0 + Q_LORA_CHUNK] = q_acc - - # Stage 2.1: qr_rms — same partial-sum pattern as attn_norm_rms (Opt U). - # Inner loop is cast-free (qr_fp32 is already FP32) so Vec pressure is lower - # than attn_norm_rms_partial, but auto_chunk is kept for parity. - Q_BLOCKS_PER_QR_PARTIAL = Q_BLOCKS // QR_RMS_PARTIALS - qr_sq_partial = pl.create_tensor([QR_RMS_PARTIALS, T], dtype=pl.FP32) - for wgr in pl.parallel(0, QR_RMS_PARTIALS, 1): - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="qr_rms_partial"): - qr_rms_q_base = wgr * Q_BLOCKS_PER_QR_PARTIAL * Q_LORA_CHUNK - qr_local_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for qr_rms_qb in pl.range(Q_BLOCKS_PER_QR_PARTIAL): - qr_rms_col0 = qr_rms_q_base + qr_rms_qb * Q_LORA_CHUNK - qr_rms_chunk = qr_fp32[:, qr_rms_col0 : qr_rms_col0 + Q_LORA_CHUNK] - qr_local_sum = pl.add(qr_local_sum, pl.reshape(pl.row_sum(pl.mul(qr_rms_chunk, qr_rms_chunk)), [1, T])) - qr_sq_partial[wgr : wgr + 1, :] = qr_local_sum - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_rms_final"): - qr_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for w in pl.range(QR_RMS_PARTIALS): - qr_sq_sum = pl.add(qr_sq_sum, qr_sq_partial[w : w + 1, :]) - qr_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(qr_sq_sum, 1.0 / Q_LORA), EPS))) - - # Stage 
2.2+2.3a partial: fused qr norm + FP32->BF16 cast + per-task amax (Opt T). - # Per-task amax is computed on qr_normed_bf16 (the same BF16 representation the - # original qr_quant_amax scope would have re-read from GM), preserving the - # bit-identical INT8 quant scale required by `qr`'s atol=1 validation. - qr_inv_rms_t = pl.reshape(qr_inv_rms, [T, 1]) + for q_inner in pl.range(2): + q_a_col0 = (qbg + q_inner) * Q_LORA_TILE + q_acc = pl.create_tensor([T, Q_LORA_TILE], dtype=pl.FP32) + for db in pl.pipeline(0, D // D_TILE, stage=2): + qr_d0 = db * D_TILE + q_x_chunk_bf16 = token_x_bf16[:, qr_d0 : qr_d0 + D_TILE] + w_chunk = wq_a[qr_d0 : qr_d0 + D_TILE, q_a_col0 : q_a_col0 + Q_LORA_TILE] + if qr_d0 == 0: + q_acc = pl.matmul(q_x_chunk_bf16, w_chunk, out_dtype=pl.FP32) + else: + q_acc = pl.matmul_acc(q_acc, q_x_chunk_bf16, w_chunk) + qr_fp32[:, q_a_col0 : q_a_col0 + Q_LORA_TILE] = q_acc + + # Stage 2.1+2.2: fused qr_rms + qr_norm_apply + per-token amax, T-tiled parallel. qr_bf16 = pl.create_tensor([T, Q_LORA], dtype=pl.BF16) - qr_amax_partial = pl.create_tensor([Q_BLOCKS, T], dtype=pl.FP32) - for qbg in pl.parallel(0, Q_BLOCKS, QR_NORM_GROUP): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_norm_apply"): - local_amax = pl.full([1, T], dtype=pl.FP32, value=INT8_AMAX_EPS) - for q_inner in pl.range(QR_NORM_GROUP): - qr_norm_col0 = (qbg + q_inner) * Q_LORA_CHUNK - qr_norm_chunk = qr_fp32[:, qr_norm_col0 : qr_norm_col0 + Q_LORA_CHUNK] + qr_amax_tensor = pl.create_tensor([1, T], dtype=pl.FP32) + for tg in pl.parallel(0, T, T_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_rms_norm"): + qr_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) + for qr_rms_qb in pl.range(Q_LORA // Q_LORA_TILE): + qr_rms_col0 = qr_rms_qb * Q_LORA_TILE + qr_rms_chunk = qr_fp32[tg : tg + T_TILE, qr_rms_col0 : qr_rms_col0 + Q_LORA_TILE] + qr_sq_sum = pl.add(qr_sq_sum, pl.reshape(pl.row_sum(pl.mul(qr_rms_chunk, qr_rms_chunk)), [1, T_TILE])) + qr_inv_rms = 
pl.recip(pl.sqrt(pl.add(pl.mul(qr_sq_sum, 1.0 / Q_LORA), EPS))) + qr_inv_rms_t = pl.reshape(qr_inv_rms, [T_TILE, 1]) + + qr_tile_amax = pl.full([1, T_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) + for qb in pl.range(Q_LORA // Q_LORA_TILE): + qr_norm_col0 = qb * Q_LORA_TILE + qr_norm_chunk = qr_fp32[tg : tg + T_TILE, qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE] gamma_chunk = pl.reshape( - pl.cast(gamma_cq[qr_norm_col0 : qr_norm_col0 + Q_LORA_CHUNK], target_type=pl.FP32), - [1, Q_LORA_CHUNK], + pl.cast(gamma_cq[qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE], target_type=pl.FP32), + [1, Q_LORA_TILE], ) qr_normed = pl.col_expand_mul(pl.row_expand_mul(qr_norm_chunk, qr_inv_rms_t), gamma_chunk) qr_normed_bf16 = pl.cast(qr_normed, target_type=pl.BF16, mode="rint") - qr_bf16[:, qr_norm_col0 : qr_norm_col0 + Q_LORA_CHUNK] = qr_normed_bf16 + qr_bf16[tg : tg + T_TILE, qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE] = qr_normed_bf16 qr_norm_amax_f32 = pl.cast(qr_normed_bf16, target_type=pl.FP32) qr_norm_amax_abs = pl.maximum(qr_norm_amax_f32, pl.neg(qr_norm_amax_f32)) - local_amax = pl.maximum(local_amax, pl.reshape(pl.row_max(qr_norm_amax_abs), [1, T])) - qr_amax_partial[qbg : qbg + 1, :] = local_amax + qr_tile_amax = pl.maximum(qr_tile_amax, pl.reshape(pl.row_max(qr_norm_amax_abs), [1, T_TILE])) + qr_amax_tensor[0:1, tg : tg + T_TILE] = qr_tile_amax - # Stage 2.3a: final amax reduce + INT8 quant scale (Opt T leaves only the - # cheap reduce + scale here; the 256-iter serial amax body is gone). + # Stage 2.3: fused INT8 quant scale + apply (single-scope serial). 
qr_scale_dq = pl.create_tensor([T, 1], dtype=pl.FP32) - qr_scale_quant_t = pl.create_tensor([T, 1], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_quant_amax"): - qr_amax = pl.full([1, T], dtype=pl.FP32, value=INT8_AMAX_EPS) - for w in pl.range(0, Q_BLOCKS, QR_NORM_GROUP): - qr_amax = pl.maximum(qr_amax, qr_amax_partial[w : w + 1, :]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_quant"): + qr_amax = qr_amax_tensor[0:1, :] qr_scale_quant_row = pl.div(pl.full([1, T], dtype=pl.FP32, value=INT8_SCALE_MAX), qr_amax) + qr_scale_quant_t = pl.reshape(qr_scale_quant_row, [T, 1]) qr_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T, 1]) qr_scale[:, :] = qr_scale_dq - qr_scale_quant_t[:, :] = pl.reshape(qr_scale_quant_row, [T, 1]) - - # Stage 2.3b: apply quantization scale (parallel over Q_LORA chunks). - for qa in pl.parallel(0, Q_LORA, QUANT_APPLY_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_quant_apply"): - for q1 in pl.range(0, QUANT_APPLY_CHUNK, QUANT_CHUNK): - qr_q_f32 = pl.cast(qr_bf16[:, qa + q1 : qa + q1 + QUANT_CHUNK], target_type=pl.FP32) - qr_q_scaled = pl.row_expand_mul(qr_q_f32, qr_scale_quant_t) - qr_q_i32 = pl.cast(qr_q_scaled, target_type=pl.INT32, mode="rint") - qr_q_half = pl.cast(qr_q_i32, target_type=pl.FP16, mode="round") - qr[:, qa + q1 : qa + q1 + QUANT_CHUNK] = pl.cast(qr_q_half, target_type=pl.INT8, mode="trunc") + for qa in pl.range(0, Q_LORA, QUANT_TILE): + qr_q_f32 = pl.cast(qr_bf16[:, qa : qa + QUANT_TILE], target_type=pl.FP32) + qr_q_scaled = pl.row_expand_mul(qr_q_f32, qr_scale_quant_t) + qr_q_i32 = pl.cast(qr_q_scaled, target_type=pl.INT32, mode="rint") + qr_q_half = pl.cast(qr_q_i32, target_type=pl.FP16, mode="round") + qr[:, qa : qa + QUANT_TILE] = pl.cast(qr_q_half, target_type=pl.INT8, mode="trunc") # Stage 3: W8A8C16 q_proj = qr_i8 @ wq_b, then dequantize to FP32. 
# qproj_matmul is GROUP-chunked (Opt J); qproj_dequant is decoupled into its own @@ -233,31 +158,25 @@ def qkv_proj_rope( # inside pl.range causes pypto AST to thread it through pl.parallel's init_values, # which fails SSA verification (see feedback_pypto_head_group_chunking_loop_carried.md). q_proj_fp32 = pl.create_tensor([T, H * HEAD_DIM], dtype=pl.FP32) - col_acc_all = pl.create_tensor([Q_PROJ_HEAD_BLOCKS * T, Q_PROJ_OUT_CHUNK], dtype=pl.INT32) - for hg in pl.parallel(0, Q_PROJ_HEAD_BLOCKS, Q_PROJ_GROUP): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qproj_matmul"): - # Pre-declare to give pypto's loop-carried init_values threading a valid - # outer source; first matmul iter overwrites this. - col_acc = pl.create_tensor([T, Q_PROJ_OUT_CHUNK], dtype=pl.INT32) - for h_inner in pl.range(Q_PROJ_GROUP): - for qb in pl.pipeline(0, Q_PROJ_BLOCKS, stage=2): - qr_proj_col0 = qb * Q_PROJ_CHUNK - qr_i8_chunk = qr[:, qr_proj_col0 : qr_proj_col0 + Q_PROJ_CHUNK] - wq_chunk = wq_b[qr_proj_col0 : qr_proj_col0 + Q_PROJ_CHUNK, (hg + h_inner) * Q_PROJ_OUT_CHUNK : (hg + h_inner) * Q_PROJ_OUT_CHUNK + Q_PROJ_OUT_CHUNK] + for hg in pl.parallel(0, (H * HEAD_DIM) // Q_PROJ_OUT_TILE, 16): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qproj"): + col_acc = pl.create_tensor([T, Q_PROJ_OUT_TILE], dtype=pl.INT32) + for h_inner in pl.range(16): + for qb in pl.pipeline(0, Q_LORA // Q_PROJ_TILE, stage=2): + qr_proj_col0 = qb * Q_PROJ_TILE + qr_i8_chunk = qr[:, qr_proj_col0 : qr_proj_col0 + Q_PROJ_TILE] + wq_chunk = wq_b[qr_proj_col0 : qr_proj_col0 + Q_PROJ_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] if qr_proj_col0 == 0: col_acc = pl.matmul(qr_i8_chunk, wq_chunk, out_dtype=pl.INT32) else: col_acc = pl.matmul_acc(col_acc, qr_i8_chunk, wq_chunk) - col_acc_all[(hg + h_inner) * T : (hg + h_inner) * T + T, :] = col_acc - - for hbg in pl.parallel(0, Q_PROJ_HEAD_BLOCKS, Q_PROJ_DEQUANT_GROUP): - with pl.at(level=pl.Level.CORE_GROUP, 
name_hint="qproj_dequant"): - for h_inner in pl.range(Q_PROJ_DEQUANT_GROUP): - col_acc_chunk = col_acc_all[(hbg + h_inner) * T : (hbg + h_inner) * T + T, :] - col_fp32 = pl.cast(col_acc_chunk, target_type=pl.FP32, mode="none") - w_scale = wq_b_scale[hbg + h_inner : hbg + h_inner + 1, :] - col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_dq), w_scale) - q_proj_fp32[:, (hbg + h_inner) * Q_PROJ_OUT_CHUNK : (hbg + h_inner) * Q_PROJ_OUT_CHUNK + Q_PROJ_OUT_CHUNK] = col_dequant + w_scale = wq_b_scale[hg + h_inner : hg + h_inner + 1, :] + for tc in pl.range(0, T, T_TILE): + col_acc_t = col_acc[tc : tc + T_TILE, :] + col_fp32 = pl.cast(col_acc_t, target_type=pl.FP32, mode="none") + qr_scale_dq_t = qr_scale_dq[tc : tc + T_TILE, :] + col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_dq_t), w_scale) + q_proj_fp32[tc : tc + T_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] = col_dequant # Stage 4: per-head RMSNorm + RoPE on q. 
# Split into q_head_rms_nope and q_head_rope at T=128 — the fused @@ -277,33 +196,35 @@ def qkv_proj_rope( q_flat = pl.reshape(q, [T, H * HEAD_DIM]) q_head_inv_rms_all = pl.create_tensor([H, T], dtype=pl.FP32) q_rope_pair_stage = pl.create_tensor([H * T, ROPE_DIM], dtype=pl.BF16) - for h in pl.parallel(0, H, 1): + for hg in pl.parallel(0, H, 2): with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_head_rms_nope"): - h0 = h * HEAD_DIM - q_head_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for db in pl.range(HEAD_DIM // HEAD_CHUNK): - d0 = h0 + db * HEAD_CHUNK - q_head_chunk = q_proj_fp32[:, d0 : d0 + HEAD_CHUNK] - q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T])) - q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) - q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1]) - q_head_inv_rms_all[h : h + 1, :] = q_head_inv_rms - - for nb in pl.range(NOPE_DIM // HEAD_CHUNK): - n0 = nb * HEAD_CHUNK - q_nope_chunk = q_proj_fp32[:, h0 + n0 : h0 + n0 + HEAD_CHUNK] - q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) - q_flat[:, h0 + n0 : h0 + n0 + HEAD_CHUNK] = pl.cast(q_normed, target_type=pl.BF16, mode="rint") + for h_inner in pl.range(2): + h = hg + h_inner + h0 = h * HEAD_DIM + q_head_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) + for db in pl.range(HEAD_DIM // HEAD_TILE): + d0 = h0 + db * HEAD_TILE + q_head_chunk = q_proj_fp32[:, d0 : d0 + HEAD_TILE] + q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T])) + q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) + q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1]) + q_head_inv_rms_all[h : h + 1, :] = q_head_inv_rms + + for nb in pl.range(NOPE_DIM // HEAD_TILE): + n0 = nb * HEAD_TILE + q_nope_chunk = q_proj_fp32[:, h0 + n0 : h0 + n0 + HEAD_TILE] + q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) + q_flat[:, h0 + n0 
: h0 + n0 + HEAD_TILE] = pl.cast(q_normed, target_type=pl.BF16, mode="rint") # q_head_rope HEAD_GROUP-chunked (Opt K). Only one cross-iter loop-carried # tensor (q_rope_pair_stage), satisfying the success condition for chunked # parallel scopes. - for hg in pl.parallel(0, H, HEAD_GROUP): + for hg in pl.parallel(0, H, 8): with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_head_rope"): q_head_inv_rms_t = pl.create_tensor([T, 1], dtype=pl.FP32) rope_cos_fp32 = pl.cast(rope_cos[:, :ROPE_HALF], target_type=pl.FP32) rope_sin_fp32 = pl.cast(rope_sin[:, :ROPE_HALF], target_type=pl.FP32) - for h_inner in pl.range(HEAD_GROUP): + for h_inner in pl.range(8): q_head_inv_rms_t = pl.reshape(q_head_inv_rms_all[hg + h_inner : hg + h_inner + 1, :], [T, 1]) q_rope = q_proj_fp32[:, (hg + h_inner) * HEAD_DIM + NOPE_DIM : (hg + h_inner) * HEAD_DIM + NOPE_DIM + ROPE_DIM] q_rope_norm = pl.row_expand_mul(q_rope, q_head_inv_rms_t) @@ -317,10 +238,10 @@ def qkv_proj_rope( q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : ROPE_DIM] = q_rot_odd_bf16 # Stage 4d: HEAD_GROUP-chunked reassemble (cube) + write (vec). 
- for hg in pl.parallel(0, H, HEAD_GROUP): - q_rope_grp_fp32 = pl.create_tensor([HEAD_GROUP * T, ROPE_DIM], dtype=pl.FP32) + for hg in pl.parallel(0, H, 8): + q_rope_grp_fp32 = pl.create_tensor([8 * T, ROPE_DIM], dtype=pl.FP32) with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble"): - for h_inner in pl.range(HEAD_GROUP): + for h_inner in pl.range(8): even_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, :ROPE_HALF] odd_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : ROPE_DIM] rot = pl.matmul( @@ -336,7 +257,7 @@ def qkv_proj_rope( q_rope_grp_fp32[h_inner * T : h_inner * T + T, :] = rot with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): - for h_inner in pl.range(HEAD_GROUP): + for h_inner in pl.range(8): rot_fp32 = q_rope_grp_fp32[h_inner * T : h_inner * T + T, :] q_flat[:, (hg + h_inner) * HEAD_DIM + NOPE_DIM : (hg + h_inner) * HEAD_DIM + NOPE_DIM + ROPE_DIM] = pl.cast(rot_fp32, target_type=pl.BF16, mode="rint") @@ -345,43 +266,45 @@ def qkv_proj_rope( # Stage 5/6: kv = rms_norm(token_x @ wkv, gamma_ckv) + RoPE. # K loop uses pl.pipeline(stage=4) per Opt X (D_BLOCKS=32, enough iters). 
kv_fp32 = pl.create_tensor([T, HEAD_DIM], dtype=pl.FP32) - for kbg in pl.parallel(0, KV_BLOCKS, KV_PROJ_GROUP): + for kb in pl.parallel(0, HEAD_DIM // KV_TILE, 1): with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_proj_matmul"): - kv_acc = pl.create_tensor([T, KV_CHUNK], dtype=pl.FP32) - for k_inner in pl.range(KV_PROJ_GROUP): - kv_col0 = (kbg + k_inner) * KV_CHUNK - for db in pl.pipeline(0, D_BLOCKS, stage=4): - d0 = db * D_CHUNK - kv_x_chunk_bf16 = token_x_bf16[:, d0 : d0 + D_CHUNK] - wkv_chunk = wkv[d0 : d0 + D_CHUNK, kv_col0 : kv_col0 + KV_CHUNK] - if d0 == 0: - kv_acc = pl.matmul(kv_x_chunk_bf16, wkv_chunk, out_dtype=pl.FP32) - else: - kv_acc = pl.matmul_acc(kv_acc, kv_x_chunk_bf16, wkv_chunk) - kv_fp32[:, kv_col0 : kv_col0 + KV_CHUNK] = kv_acc + kv_acc = pl.create_tensor([T, KV_TILE], dtype=pl.FP32) + kv_col0 = kb * KV_TILE + for db in pl.pipeline(0, D // D_TILE, stage=2): + d0 = db * D_TILE + kv_x_chunk_bf16 = token_x_bf16[:, d0 : d0 + D_TILE] + wkv_chunk = wkv[d0 : d0 + D_TILE, kv_col0 : kv_col0 + KV_TILE] + if d0 == 0: + kv_acc = pl.matmul(kv_x_chunk_bf16, wkv_chunk, out_dtype=pl.FP32) + else: + kv_acc = pl.matmul_acc(kv_acc, kv_x_chunk_bf16, wkv_chunk) + kv_fp32[:, kv_col0 : kv_col0 + KV_TILE] = kv_acc - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rms"): + kv_inv_rms_tensor = pl.create_tensor([1, T], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rms_norm"): kv_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for kb in pl.range(KV_BLOCKS): - kv_sq_col0 = kb * KV_CHUNK - kv_chunk = kv_fp32[:, kv_sq_col0 : kv_sq_col0 + KV_CHUNK] + for kb in pl.range(HEAD_DIM // KV_TILE): + kv_sq_col0 = kb * KV_TILE + kv_chunk = kv_fp32[:, kv_sq_col0 : kv_sq_col0 + KV_TILE] kv_sq_sum = pl.add(kv_sq_sum, pl.reshape(pl.row_sum(pl.mul(kv_chunk, kv_chunk)), [1, T])) kv_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(kv_sq_sum, 1.0 / HEAD_DIM), EPS))) + kv_inv_rms_t = pl.reshape(kv_inv_rms, [T, 1]) + kv_inv_rms_tensor[0:1, :] = kv_inv_rms - 
kv_inv_rms_t = pl.reshape(kv_inv_rms, [T, 1]) - for nb in pl.parallel(0, NOPE_DIM // KV_CHUNK, 1): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_norm_nope"): - n0 = nb * KV_CHUNK - kv_chunk = kv_fp32[:, n0 : n0 + KV_CHUNK] + for nb in pl.range(NOPE_DIM // KV_TILE): + n0 = nb * KV_TILE + kv_chunk = kv_fp32[:, n0 : n0 + KV_TILE] gamma_kv_chunk = pl.reshape( - pl.cast(gamma_ckv[n0 : n0 + KV_CHUNK], target_type=pl.FP32), - [1, KV_CHUNK], + pl.cast(gamma_ckv[n0 : n0 + KV_TILE], target_type=pl.FP32), + [1, KV_TILE], ) kv_normed = pl.col_expand_mul(pl.row_expand_mul(kv_chunk, kv_inv_rms_t), gamma_kv_chunk) - kv[:, n0 : n0 + KV_CHUNK] = pl.cast(kv_normed, target_type=pl.BF16, mode="rint") + kv[:, n0 : n0 + KV_TILE] = pl.cast(kv_normed, target_type=pl.BF16, mode="rint") + kv_rot_even_tmp = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) kv_rot_odd_tmp = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_apply"): + kv_inv_rms_t = pl.reshape(kv_inv_rms_tensor[0:1, :], [T, 1]) kv_rope = kv_fp32[:, NOPE_DIM : NOPE_DIM + ROPE_DIM] gamma_rope = pl.reshape( pl.cast(gamma_ckv[NOPE_DIM : NOPE_DIM + ROPE_DIM], target_type=pl.FP32), @@ -399,21 +322,21 @@ def qkv_proj_rope( kv_rope_full = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"): - for rope_col in pl.range(0, ROPE_DIM, ROPE_CHUNK): + for rope_col in pl.range(0, ROPE_DIM, ROPE_TILE): pair_col = rope_col // 2 - kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] - kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] + kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_TILE] + kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_TILE] kv_rot_chunk = pl.matmul( kv_rot_even_chunk, - even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], + even_select_t[pair_col : pair_col + ROPE_PAIR_TILE, 
rope_col : rope_col + ROPE_TILE], out_dtype=pl.FP32, ) kv_rot_chunk = pl.matmul_acc( kv_rot_chunk, kv_rot_odd_chunk, - odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], + odd_select_t[pair_col : pair_col + ROPE_PAIR_TILE, rope_col : rope_col + ROPE_TILE], ) - kv_rope_full[:, rope_col : rope_col + ROPE_CHUNK] = kv_rot_chunk + kv_rope_full[:, rope_col : rope_col + ROPE_TILE] = kv_rot_chunk with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"): kv[:, NOPE_DIM : NOPE_DIM + ROPE_DIM] = pl.cast(kv_rope_full, target_type=pl.BF16, mode="rint") @@ -427,7 +350,7 @@ def qkv_proj_rope_test( norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[(H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], rope_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], rope_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], @@ -582,14 +505,14 @@ def init_gamma_ckv(): wq_b_bf16 = init_wq_b().to(torch.bfloat16) wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16) - wq_b_scale = wq_b_scale.view(Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK) + wq_b_scale = wq_b_scale.view((H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE) return [ TensorSpec("x", [B, S, D], torch.bfloat16, init_value=init_x), TensorSpec("norm_w", [D], torch.float32, init_value=init_norm_w), TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=init_wq_a), TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8), - TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32, init_value=lambda: wq_b_scale), + TensorSpec("wq_b_scale", [(H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE], torch.float32, init_value=lambda: wq_b_scale), TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv), TensorSpec("rope_cos", [T, ROPE_DIM], 
torch.bfloat16, init_value=init_cos), TensorSpec("rope_sin", [T, ROPE_DIM], torch.bfloat16, init_value=init_sin),

From 3443c0fe0f885b9de0a32c2e0c2279e78821bf68 Mon Sep 17 00:00:00 2001
From: zhangqi-chen
Date: Tue, 26 May 2026 14:16:45 +0800
Subject: [PATCH 2/8] Refactor: dsv4 qkv_proj_rope with pl.spmd and qr quant fusion

- Convert all parallel scopes (attn_norm, qr_proj_matmul, qr_rms_norm, qproj, q_head_rms_nope, q_head_rope, q_rope_reassemble, q_rope_write, kv_proj_matmul) from pl.parallel + pl.at to pl.spmd dispatch
- Halve kv_proj_matmul task count (group=2 inner pl.range)
- Fuse qr_rms_norm + qr_quant into one T-tiled spmd scope using a two-pass design: pass 1 computes amax without GM staging, pass 2 recomputes norm and quantizes; drops qr_bf16 and qr_scale_dq GM intermediates and keeps kernel outputs at 2 to sidestep the pypto multi-InOut OptimizeOrchTensors alias bug
- qproj_dequant reads qr_scale directly (same values as qr_scale_dq)
- Split q_rope reassemble/write into two spmd scopes sharing a GM staging tensor (cube+vec mixing in one scope blows Vec UB)
---
 models/deepseek/v4/qkv_proj_rope.py | 311 ++++++++++++++--------------
 1 file changed, 159 insertions(+), 152 deletions(-)

diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py
index cc2821fd..1ac460ba 100644
--- a/models/deepseek/v4/qkv_proj_rope.py
+++ b/models/deepseek/v4/qkv_proj_rope.py
@@ -67,86 +67,92 @@ def qkv_proj_rope( ): x_flat = pl.reshape(x, [T, D]) - # Stage 0.1+0.2: attn_norm fused RMS + apply, T-tiled parallel (per-token reduction). + # Stage 0.1+0.2: attn_norm fused RMS + apply, T-tiled SPMD (per-token reduction).
token_x_bf16 = pl.create_tensor([T, D], dtype=pl.BF16) - for tg in pl.parallel(0, T, T_TILE): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="attn_norm"): - x_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) - for rms_db in pl.range(D // D_TILE): - rms_d0 = rms_db * D_TILE - rms_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, rms_d0 : rms_d0 + D_TILE], target_type=pl.FP32) - x_sq_sum = pl.add(x_sq_sum, pl.reshape(pl.row_sum(pl.mul(rms_x_chunk, rms_x_chunk)), [1, T_TILE])) - x_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(x_sq_sum, 1.0 / D), EPS))) - x_inv_rms_t = pl.reshape(x_inv_rms, [T_TILE, 1]) - for apply_db in pl.range(D // D_TILE): - apply_d0 = apply_db * D_TILE - apply_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, apply_d0 : apply_d0 + D_TILE], target_type=pl.FP32) - norm_w_chunk = pl.reshape(norm_w[apply_d0 : apply_d0 + D_TILE], [1, D_TILE]) - x_normed = pl.col_expand_mul(pl.row_expand_mul(apply_x_chunk, x_inv_rms_t), norm_w_chunk) - token_x_bf16[tg : tg + T_TILE, apply_d0 : apply_d0 + D_TILE] = pl.cast(x_normed, target_type=pl.BF16, mode="rint") + for tg_idx in pl.spmd(T // T_TILE, name_hint="attn_norm"): + tg = tg_idx * T_TILE + x_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) + for rms_db in pl.range(D // D_TILE): + rms_d0 = rms_db * D_TILE + rms_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, rms_d0 : rms_d0 + D_TILE], target_type=pl.FP32) + x_sq_sum = pl.add(x_sq_sum, pl.reshape(pl.row_sum(pl.mul(rms_x_chunk, rms_x_chunk)), [1, T_TILE])) + x_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(x_sq_sum, 1.0 / D), EPS))) + x_inv_rms_t = pl.reshape(x_inv_rms, [T_TILE, 1]) + for apply_db in pl.range(D // D_TILE): + apply_d0 = apply_db * D_TILE + apply_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, apply_d0 : apply_d0 + D_TILE], target_type=pl.FP32) + norm_w_chunk = pl.reshape(norm_w[apply_d0 : apply_d0 + D_TILE], [1, D_TILE]) + x_normed = pl.col_expand_mul(pl.row_expand_mul(apply_x_chunk, x_inv_rms_t), norm_w_chunk) + token_x_bf16[tg : tg + T_TILE, apply_d0 : apply_d0 + 
D_TILE] = pl.cast(x_normed, target_type=pl.BF16, mode="rint") # Stage 1/2.1: qr = rms_norm(token_x @ wq_a, gamma_cq). # K loop uses pl.pipeline(stage=4) for 4-deep ping-pong on the D=4096 input # projection (D_BLOCKS=32, sufficient iter count for 4-stage replication). qr_fp32 = pl.create_tensor([T, Q_LORA], dtype=pl.FP32) - for qbg in pl.parallel(0, Q_LORA // Q_LORA_TILE, 2): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_proj_matmul"): - for q_inner in pl.range(2): - q_a_col0 = (qbg + q_inner) * Q_LORA_TILE - q_acc = pl.create_tensor([T, Q_LORA_TILE], dtype=pl.FP32) - for db in pl.pipeline(0, D // D_TILE, stage=2): - qr_d0 = db * D_TILE - q_x_chunk_bf16 = token_x_bf16[:, qr_d0 : qr_d0 + D_TILE] - w_chunk = wq_a[qr_d0 : qr_d0 + D_TILE, q_a_col0 : q_a_col0 + Q_LORA_TILE] - if qr_d0 == 0: - q_acc = pl.matmul(q_x_chunk_bf16, w_chunk, out_dtype=pl.FP32) - else: - q_acc = pl.matmul_acc(q_acc, q_x_chunk_bf16, w_chunk) - qr_fp32[:, q_a_col0 : q_a_col0 + Q_LORA_TILE] = q_acc - - # Stage 2.1+2.2: fused qr_rms + qr_norm_apply + per-token amax, T-tiled parallel. 
- qr_bf16 = pl.create_tensor([T, Q_LORA], dtype=pl.BF16) - qr_amax_tensor = pl.create_tensor([1, T], dtype=pl.FP32) - for tg in pl.parallel(0, T, T_TILE): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_rms_norm"): - qr_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) - for qr_rms_qb in pl.range(Q_LORA // Q_LORA_TILE): - qr_rms_col0 = qr_rms_qb * Q_LORA_TILE - qr_rms_chunk = qr_fp32[tg : tg + T_TILE, qr_rms_col0 : qr_rms_col0 + Q_LORA_TILE] - qr_sq_sum = pl.add(qr_sq_sum, pl.reshape(pl.row_sum(pl.mul(qr_rms_chunk, qr_rms_chunk)), [1, T_TILE])) - qr_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(qr_sq_sum, 1.0 / Q_LORA), EPS))) - qr_inv_rms_t = pl.reshape(qr_inv_rms, [T_TILE, 1]) - - qr_tile_amax = pl.full([1, T_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) - for qb in pl.range(Q_LORA // Q_LORA_TILE): - qr_norm_col0 = qb * Q_LORA_TILE - qr_norm_chunk = qr_fp32[tg : tg + T_TILE, qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE] - gamma_chunk = pl.reshape( - pl.cast(gamma_cq[qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE], target_type=pl.FP32), - [1, Q_LORA_TILE], - ) - qr_normed = pl.col_expand_mul(pl.row_expand_mul(qr_norm_chunk, qr_inv_rms_t), gamma_chunk) - qr_normed_bf16 = pl.cast(qr_normed, target_type=pl.BF16, mode="rint") - qr_bf16[tg : tg + T_TILE, qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE] = qr_normed_bf16 - qr_norm_amax_f32 = pl.cast(qr_normed_bf16, target_type=pl.FP32) - qr_norm_amax_abs = pl.maximum(qr_norm_amax_f32, pl.neg(qr_norm_amax_f32)) - qr_tile_amax = pl.maximum(qr_tile_amax, pl.reshape(pl.row_max(qr_norm_amax_abs), [1, T_TILE])) - qr_amax_tensor[0:1, tg : tg + T_TILE] = qr_tile_amax - - # Stage 2.3: fused INT8 quant scale + apply (single-scope serial). 
- qr_scale_dq = pl.create_tensor([T, 1], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_quant"): - qr_amax = qr_amax_tensor[0:1, :] - qr_scale_quant_row = pl.div(pl.full([1, T], dtype=pl.FP32, value=INT8_SCALE_MAX), qr_amax) - qr_scale_quant_t = pl.reshape(qr_scale_quant_row, [T, 1]) - qr_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T, 1]) - qr_scale[:, :] = qr_scale_dq + for qbg_idx in pl.spmd((Q_LORA // Q_LORA_TILE) // 2, name_hint="qr_proj_matmul"): + qbg = qbg_idx * 2 + for q_inner in pl.range(2): + q_a_col0 = (qbg + q_inner) * Q_LORA_TILE + q_acc = pl.create_tensor([T, Q_LORA_TILE], dtype=pl.FP32) + for db in pl.pipeline(0, D // D_TILE, stage=2): + qr_d0 = db * D_TILE + q_x_chunk_bf16 = token_x_bf16[:, qr_d0 : qr_d0 + D_TILE] + w_chunk = wq_a[qr_d0 : qr_d0 + D_TILE, q_a_col0 : q_a_col0 + Q_LORA_TILE] + if qr_d0 == 0: + q_acc = pl.matmul(q_x_chunk_bf16, w_chunk, out_dtype=pl.FP32) + else: + q_acc = pl.matmul_acc(q_acc, q_x_chunk_bf16, w_chunk) + qr_fp32[:, q_a_col0 : q_a_col0 + Q_LORA_TILE] = q_acc + + # Stage 2.1+2.2+2.3: fused qr_rms + norm + amax + INT8 quant, T-tiled SPMD. + # Two-pass within block: first pass computes RMS + amax (no qr_bf16 write); + # second pass recomputes norm and quantizes directly. Saves the qr_bf16 GM + # staging tensor and reduces kernel outputs to 2 (qr_scale, qr), avoiding the + # pypto multi-InOut OptimizeOrchTensors alias bug. 
+ for tg_idx in pl.spmd(T // T_TILE, name_hint="qr_rms_norm_quant"): + tg = tg_idx * T_TILE + qr_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) + qr_tile_amax = pl.full([1, T_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) + for qr_rms_qb in pl.range(Q_LORA // Q_LORA_TILE): + qr_rms_col0 = qr_rms_qb * Q_LORA_TILE + qr_rms_chunk = qr_fp32[tg : tg + T_TILE, qr_rms_col0 : qr_rms_col0 + Q_LORA_TILE] + qr_sq_sum = pl.add(qr_sq_sum, pl.reshape(pl.row_sum(pl.mul(qr_rms_chunk, qr_rms_chunk)), [1, T_TILE])) + qr_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(qr_sq_sum, 1.0 / Q_LORA), EPS))) + qr_inv_rms_t = pl.reshape(qr_inv_rms, [T_TILE, 1]) + + # Pass 1: norm + amax accumulation (no GM write of the normalized values). + for qb in pl.range(Q_LORA // Q_LORA_TILE): + qr_norm_col0 = qb * Q_LORA_TILE + qr_norm_chunk = qr_fp32[tg : tg + T_TILE, qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE] + gamma_chunk = pl.reshape( + pl.cast(gamma_cq[qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE], target_type=pl.FP32), + [1, Q_LORA_TILE], + ) + qr_normed = pl.col_expand_mul(pl.row_expand_mul(qr_norm_chunk, qr_inv_rms_t), gamma_chunk) + qr_normed_bf16 = pl.cast(qr_normed, target_type=pl.BF16, mode="rint") + qr_norm_amax_f32 = pl.cast(qr_normed_bf16, target_type=pl.FP32) + qr_norm_amax_abs = pl.maximum(qr_norm_amax_f32, pl.neg(qr_norm_amax_f32)) + qr_tile_amax = pl.maximum(qr_tile_amax, pl.reshape(pl.row_max(qr_norm_amax_abs), [1, T_TILE])) + + qr_scale_quant_row = pl.div(pl.full([1, T_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), qr_tile_amax) + qr_scale_quant_t = pl.reshape(qr_scale_quant_row, [T_TILE, 1]) + qr_tile_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T_TILE, 1]) + qr_scale[tg : tg + T_TILE, :] = qr_tile_scale_dq + + # Pass 2: recompute norm and quantize directly to qr. 
for qa in pl.range(0, Q_LORA, QUANT_TILE): - qr_q_f32 = pl.cast(qr_bf16[:, qa : qa + QUANT_TILE], target_type=pl.FP32) + qr_chunk = qr_fp32[tg : tg + T_TILE, qa : qa + QUANT_TILE] + gamma_q_chunk = pl.reshape( + pl.cast(gamma_cq[qa : qa + QUANT_TILE], target_type=pl.FP32), + [1, QUANT_TILE], + ) + qr_q_normed = pl.col_expand_mul(pl.row_expand_mul(qr_chunk, qr_inv_rms_t), gamma_q_chunk) + qr_q_normed_bf16 = pl.cast(qr_q_normed, target_type=pl.BF16, mode="rint") + qr_q_f32 = pl.cast(qr_q_normed_bf16, target_type=pl.FP32) qr_q_scaled = pl.row_expand_mul(qr_q_f32, qr_scale_quant_t) qr_q_i32 = pl.cast(qr_q_scaled, target_type=pl.INT32, mode="rint") qr_q_half = pl.cast(qr_q_i32, target_type=pl.FP16, mode="round") - qr[:, qa : qa + QUANT_TILE] = pl.cast(qr_q_half, target_type=pl.INT8, mode="trunc") + qr[tg : tg + T_TILE, qa : qa + QUANT_TILE] = pl.cast(qr_q_half, target_type=pl.INT8, mode="trunc") # Stage 3: W8A8C16 q_proj = qr_i8 @ wq_b, then dequantize to FP32. # qproj_matmul is GROUP-chunked (Opt J); qproj_dequant is decoupled into its own @@ -158,25 +164,25 @@ def qkv_proj_rope( # inside pl.range causes pypto AST to thread it through pl.parallel's init_values, # which fails SSA verification (see feedback_pypto_head_group_chunking_loop_carried.md). 
q_proj_fp32 = pl.create_tensor([T, H * HEAD_DIM], dtype=pl.FP32) - for hg in pl.parallel(0, (H * HEAD_DIM) // Q_PROJ_OUT_TILE, 16): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qproj"): - col_acc = pl.create_tensor([T, Q_PROJ_OUT_TILE], dtype=pl.INT32) - for h_inner in pl.range(16): - for qb in pl.pipeline(0, Q_LORA // Q_PROJ_TILE, stage=2): - qr_proj_col0 = qb * Q_PROJ_TILE - qr_i8_chunk = qr[:, qr_proj_col0 : qr_proj_col0 + Q_PROJ_TILE] - wq_chunk = wq_b[qr_proj_col0 : qr_proj_col0 + Q_PROJ_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] - if qr_proj_col0 == 0: - col_acc = pl.matmul(qr_i8_chunk, wq_chunk, out_dtype=pl.INT32) - else: - col_acc = pl.matmul_acc(col_acc, qr_i8_chunk, wq_chunk) - w_scale = wq_b_scale[hg + h_inner : hg + h_inner + 1, :] - for tc in pl.range(0, T, T_TILE): - col_acc_t = col_acc[tc : tc + T_TILE, :] - col_fp32 = pl.cast(col_acc_t, target_type=pl.FP32, mode="none") - qr_scale_dq_t = qr_scale_dq[tc : tc + T_TILE, :] - col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_dq_t), w_scale) - q_proj_fp32[tc : tc + T_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] = col_dequant + for hg_idx in pl.spmd(((H * HEAD_DIM) // Q_PROJ_OUT_TILE) // 16, name_hint="qproj"): + hg = hg_idx * 16 + col_acc = pl.create_tensor([T, Q_PROJ_OUT_TILE], dtype=pl.INT32) + for h_inner in pl.range(16): + for qb in pl.pipeline(0, Q_LORA // Q_PROJ_TILE, stage=2): + qr_proj_col0 = qb * Q_PROJ_TILE + qr_i8_chunk = qr[:, qr_proj_col0 : qr_proj_col0 + Q_PROJ_TILE] + wq_chunk = wq_b[qr_proj_col0 : qr_proj_col0 + Q_PROJ_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] + if qr_proj_col0 == 0: + col_acc = pl.matmul(qr_i8_chunk, wq_chunk, out_dtype=pl.INT32) + else: + col_acc = pl.matmul_acc(col_acc, qr_i8_chunk, wq_chunk) + w_scale = wq_b_scale[hg + h_inner : hg + h_inner + 1, :] + for tc in pl.range(0, T, T_TILE): + col_acc_t 
= col_acc[tc : tc + T_TILE, :] + col_fp32 = pl.cast(col_acc_t, target_type=pl.FP32, mode="none") + qr_scale_dq_t = qr_scale[tc : tc + T_TILE, :] + col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_dq_t), w_scale) + q_proj_fp32[tc : tc + T_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] = col_dequant # Stage 4: per-head RMSNorm + RoPE on q. # Split into q_head_rms_nope and q_head_rope at T=128 — the fused @@ -196,80 +202,81 @@ def qkv_proj_rope( q_flat = pl.reshape(q, [T, H * HEAD_DIM]) q_head_inv_rms_all = pl.create_tensor([H, T], dtype=pl.FP32) q_rope_pair_stage = pl.create_tensor([H * T, ROPE_DIM], dtype=pl.BF16) - for hg in pl.parallel(0, H, 2): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_head_rms_nope"): - for h_inner in pl.range(2): - h = hg + h_inner - h0 = h * HEAD_DIM - q_head_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for db in pl.range(HEAD_DIM // HEAD_TILE): - d0 = h0 + db * HEAD_TILE - q_head_chunk = q_proj_fp32[:, d0 : d0 + HEAD_TILE] - q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T])) - q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) - q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1]) - q_head_inv_rms_all[h : h + 1, :] = q_head_inv_rms - - for nb in pl.range(NOPE_DIM // HEAD_TILE): - n0 = nb * HEAD_TILE - q_nope_chunk = q_proj_fp32[:, h0 + n0 : h0 + n0 + HEAD_TILE] - q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) - q_flat[:, h0 + n0 : h0 + n0 + HEAD_TILE] = pl.cast(q_normed, target_type=pl.BF16, mode="rint") + for hg_idx in pl.spmd(H // 2, name_hint="q_head_rms_nope"): + hg = hg_idx * 2 + for h_inner in pl.range(2): + h = hg + h_inner + h0 = h * HEAD_DIM + q_head_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) + for db in pl.range(HEAD_DIM // HEAD_TILE): + d0 = h0 + db * HEAD_TILE + q_head_chunk = q_proj_fp32[:, d0 : d0 + HEAD_TILE] + q_head_sq_sum = 
pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T])) + q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) + q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1]) + q_head_inv_rms_all[h : h + 1, :] = q_head_inv_rms + + for nb in pl.range(NOPE_DIM // HEAD_TILE): + n0 = nb * HEAD_TILE + q_nope_chunk = q_proj_fp32[:, h0 + n0 : h0 + n0 + HEAD_TILE] + q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) + q_flat[:, h0 + n0 : h0 + n0 + HEAD_TILE] = pl.cast(q_normed, target_type=pl.BF16, mode="rint") # q_head_rope HEAD_GROUP-chunked (Opt K). Only one cross-iter loop-carried # tensor (q_rope_pair_stage), satisfying the success condition for chunked # parallel scopes. - for hg in pl.parallel(0, H, 8): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_head_rope"): - q_head_inv_rms_t = pl.create_tensor([T, 1], dtype=pl.FP32) - rope_cos_fp32 = pl.cast(rope_cos[:, :ROPE_HALF], target_type=pl.FP32) - rope_sin_fp32 = pl.cast(rope_sin[:, :ROPE_HALF], target_type=pl.FP32) - for h_inner in pl.range(8): - q_head_inv_rms_t = pl.reshape(q_head_inv_rms_all[hg + h_inner : hg + h_inner + 1, :], [T, 1]) - q_rope = q_proj_fp32[:, (hg + h_inner) * HEAD_DIM + NOPE_DIM : (hg + h_inner) * HEAD_DIM + NOPE_DIM + ROPE_DIM] - q_rope_norm = pl.row_expand_mul(q_rope, q_head_inv_rms_t) - q_even = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P0101) - q_odd = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P1010) - q_rot_even = pl.sub(pl.mul(q_even, rope_cos_fp32), pl.mul(q_odd, rope_sin_fp32)) - q_rot_odd = pl.add(pl.mul(q_even, rope_sin_fp32), pl.mul(q_odd, rope_cos_fp32)) - q_rot_even_bf16 = pl.cast(q_rot_even, target_type=pl.BF16, mode="rint") - q_rot_odd_bf16 = pl.cast(q_rot_odd, target_type=pl.BF16, mode="rint") - q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, :ROPE_HALF] = q_rot_even_bf16 - q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : 
ROPE_DIM] = q_rot_odd_bf16 - - # Stage 4d: HEAD_GROUP-chunked reassemble (cube) + write (vec). - for hg in pl.parallel(0, H, 8): - q_rope_grp_fp32 = pl.create_tensor([8 * T, ROPE_DIM], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble"): - for h_inner in pl.range(8): - even_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, :ROPE_HALF] - odd_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : ROPE_DIM] - rot = pl.matmul( - even_chunk, - even_select_t[:, :], - out_dtype=pl.FP32, - ) - rot = pl.matmul_acc( - rot, - odd_chunk, - odd_select_t[:, :], - ) - q_rope_grp_fp32[h_inner * T : h_inner * T + T, :] = rot - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): - for h_inner in pl.range(8): - rot_fp32 = q_rope_grp_fp32[h_inner * T : h_inner * T + T, :] - q_flat[:, (hg + h_inner) * HEAD_DIM + NOPE_DIM : (hg + h_inner) * HEAD_DIM + NOPE_DIM + ROPE_DIM] = pl.cast(rot_fp32, target_type=pl.BF16, mode="rint") + for hg_idx in pl.spmd(H // 8, name_hint="q_head_rope"): + hg = hg_idx * 8 + q_head_inv_rms_t = pl.create_tensor([T, 1], dtype=pl.FP32) + rope_cos_fp32 = pl.cast(rope_cos[:, :ROPE_HALF], target_type=pl.FP32) + rope_sin_fp32 = pl.cast(rope_sin[:, :ROPE_HALF], target_type=pl.FP32) + for h_inner in pl.range(8): + q_head_inv_rms_t = pl.reshape(q_head_inv_rms_all[hg + h_inner : hg + h_inner + 1, :], [T, 1]) + q_rope = q_proj_fp32[:, (hg + h_inner) * HEAD_DIM + NOPE_DIM : (hg + h_inner) * HEAD_DIM + NOPE_DIM + ROPE_DIM] + q_rope_norm = pl.row_expand_mul(q_rope, q_head_inv_rms_t) + q_even = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P0101) + q_odd = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P1010) + q_rot_even = pl.sub(pl.mul(q_even, rope_cos_fp32), pl.mul(q_odd, rope_sin_fp32)) + q_rot_odd = pl.add(pl.mul(q_even, rope_sin_fp32), pl.mul(q_odd, rope_cos_fp32)) + q_rot_even_bf16 = pl.cast(q_rot_even, target_type=pl.BF16, mode="rint") + 
q_rot_odd_bf16 = pl.cast(q_rot_odd, target_type=pl.BF16, mode="rint") + q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, :ROPE_HALF] = q_rot_even_bf16 + q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : ROPE_DIM] = q_rot_odd_bf16 + + # Stage 4d: reassemble (cube) + write (vec) — two SPMD scopes share a GM staging tensor. + q_rope_grp_fp32 = pl.create_tensor([H * T, ROPE_DIM], dtype=pl.FP32) + for hg_idx in pl.spmd(H // 8, name_hint="q_rope_reassemble"): + hg = hg_idx * 8 + for h_inner in pl.range(8): + even_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, :ROPE_HALF] + odd_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : ROPE_DIM] + rot = pl.matmul( + even_chunk, + even_select_t[:, :], + out_dtype=pl.FP32, + ) + rot = pl.matmul_acc( + rot, + odd_chunk, + odd_select_t[:, :], + ) + q_rope_grp_fp32[(hg + h_inner) * T : (hg + h_inner) * T + T, :] = rot + + for hg_idx in pl.spmd(H // 8, name_hint="q_rope_write"): + hg = hg_idx * 8 + for h_inner in pl.range(8): + rot_fp32 = q_rope_grp_fp32[(hg + h_inner) * T : (hg + h_inner) * T + T, :] + q_flat[:, (hg + h_inner) * HEAD_DIM + NOPE_DIM : (hg + h_inner) * HEAD_DIM + NOPE_DIM + ROPE_DIM] = pl.cast(rot_fp32, target_type=pl.BF16, mode="rint") q = pl.reshape(q_flat, [T, H, HEAD_DIM]) # Stage 5/6: kv = rms_norm(token_x @ wkv, gamma_ckv) + RoPE. # K loop uses pl.pipeline(stage=4) per Opt X (D_BLOCKS=32, enough iters). 
kv_fp32 = pl.create_tensor([T, HEAD_DIM], dtype=pl.FP32) - for kb in pl.parallel(0, HEAD_DIM // KV_TILE, 1): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_proj_matmul"): + for kbg in pl.spmd((HEAD_DIM // KV_TILE) // 2, name_hint="kv_proj_matmul"): + for k_inner in pl.range(2): kv_acc = pl.create_tensor([T, KV_TILE], dtype=pl.FP32) - kv_col0 = kb * KV_TILE + kv_col0 = (kbg * 2 + k_inner) * KV_TILE for db in pl.pipeline(0, D // D_TILE, stage=2): d0 = db * D_TILE kv_x_chunk_bf16 = token_x_bf16[:, d0 : d0 + D_TILE] From 6c8f229db1b26dafece17c6e838d6aa9d35ff932 Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 26 May 2026 14:27:19 +0800 Subject: [PATCH 3/8] Refactor: trim qkv_proj_rope comments and flatten wq_b_scale shape - Drop stage labels, tuning-log references, and bug-rationale notes; keep only short, code-functional comments - Remove table-aligned constants and tensor signatures - Flatten wq_b_scale from [H_BLOCKS, OUT_TILE] to [H * HEAD_DIM] so the external shape no longer depends on internal tiling; reshape per-tile at the qproj use site - Move q_rope_pair_stage allocation next to its first writer - Drop a stale assert (H * HEAD_DIM) % (HEAD_TILE * 8) (recombines unrelated constraints already implied by the loop bounds) --- models/deepseek/v4/qkv_proj_rope.py | 169 +++++++++++----------------- 1 file changed, 64 insertions(+), 105 deletions(-) diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py index 1ac460ba..bfffde49 100644 --- a/models/deepseek/v4/qkv_proj_rope.py +++ b/models/deepseek/v4/qkv_proj_rope.py @@ -16,58 +16,52 @@ # model config -B = DECODE_BATCH -S = DECODE_SEQ -T = B * S -D = M.hidden_size -H = M.num_attention_heads -HEAD_DIM = M.head_dim -ROPE_DIM = M.qk_rope_head_dim -ROPE_HALF = ROPE_DIM // 2 -NOPE_DIM = M.nope_head_dim -Q_LORA = M.q_lora_rank -EPS = M.rms_norm_eps +B = DECODE_BATCH +S = DECODE_SEQ +T = B * S +D = M.hidden_size +H = M.num_attention_heads +HEAD_DIM = M.head_dim +ROPE_DIM = 
M.qk_rope_head_dim +ROPE_HALF = ROPE_DIM // 2 +NOPE_DIM = M.nope_head_dim +Q_LORA = M.q_lora_rank +EPS = M.rms_norm_eps # tiling -# Group constants control pl.parallel(0, N, GROUP) + pl.range(GROUP) folding — -# how many logical chunks are fused into one InCore task. See Opt J/K/L/N/O/P -# in docs/dsv4-qkv-proj-rope-perf-tuning.md for the per-scope sweep results. -ROPE_TILE = 64 +ROPE_TILE = 64 ROPE_PAIR_TILE = ROPE_TILE // 2 -HEAD_TILE = 64 +HEAD_TILE = 64 Q_PROJ_OUT_TILE = 128 -Q_PROJ_TILE = 512 # K-tile +Q_PROJ_TILE = 512 Q_LORA_TILE = 32 -D_TILE = 128 -KV_TILE = 32 +D_TILE = 128 +KV_TILE = 32 QUANT_TILE = 32 -T_TILE = 16 # T-axis sub-tile for qproj dequant (keeps cube+vec fused scope under Vec UB) -assert (H * HEAD_DIM) % (HEAD_TILE * 8) == 0, \ - "HEAD_BLOCKS must be divisible by 8" +T_TILE = 16 @pl.jit.inline def qkv_proj_rope( - x: pl.Tensor[[B, S, D], pl.BF16], - norm_w: pl.Tensor[[D], pl.FP32], - wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], - wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[(H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE], pl.FP32], - wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], - rope_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], - rope_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], + x: pl.Tensor[[B, S, D], pl.BF16], + norm_w: pl.Tensor[[D], pl.FP32], + wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], + wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], + wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], + rope_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], + rope_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], even_select_t: pl.Tensor[[ROPE_HALF, ROPE_DIM], pl.BF16], - odd_select_t: pl.Tensor[[ROPE_HALF, ROPE_DIM], pl.BF16], - gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], - gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], - q: pl.Tensor[[T, H, HEAD_DIM], pl.BF16], - kv: pl.Tensor[[T, HEAD_DIM], pl.BF16], - qr: pl.Tensor[[T, Q_LORA], pl.INT8], - qr_scale: pl.Tensor[[T, 1], pl.FP32], + odd_select_t: pl.Tensor[[ROPE_HALF, ROPE_DIM], 
pl.BF16], + gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], + gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], + q: pl.Tensor[[T, H, HEAD_DIM], pl.BF16], + kv: pl.Tensor[[T, HEAD_DIM], pl.BF16], + qr: pl.Tensor[[T, Q_LORA], pl.INT8], + qr_scale: pl.Tensor[[T, 1], pl.FP32], ): x_flat = pl.reshape(x, [T, D]) - # Stage 0.1+0.2: attn_norm fused RMS + apply, T-tiled SPMD (per-token reduction). token_x_bf16 = pl.create_tensor([T, D], dtype=pl.BF16) for tg_idx in pl.spmd(T // T_TILE, name_hint="attn_norm"): tg = tg_idx * T_TILE @@ -85,9 +79,6 @@ def qkv_proj_rope( x_normed = pl.col_expand_mul(pl.row_expand_mul(apply_x_chunk, x_inv_rms_t), norm_w_chunk) token_x_bf16[tg : tg + T_TILE, apply_d0 : apply_d0 + D_TILE] = pl.cast(x_normed, target_type=pl.BF16, mode="rint") - # Stage 1/2.1: qr = rms_norm(token_x @ wq_a, gamma_cq). - # K loop uses pl.pipeline(stage=4) for 4-deep ping-pong on the D=4096 input - # projection (D_BLOCKS=32, sufficient iter count for 4-stage replication). qr_fp32 = pl.create_tensor([T, Q_LORA], dtype=pl.FP32) for qbg_idx in pl.spmd((Q_LORA // Q_LORA_TILE) // 2, name_hint="qr_proj_matmul"): qbg = qbg_idx * 2 @@ -104,11 +95,7 @@ def qkv_proj_rope( q_acc = pl.matmul_acc(q_acc, q_x_chunk_bf16, w_chunk) qr_fp32[:, q_a_col0 : q_a_col0 + Q_LORA_TILE] = q_acc - # Stage 2.1+2.2+2.3: fused qr_rms + norm + amax + INT8 quant, T-tiled SPMD. - # Two-pass within block: first pass computes RMS + amax (no qr_bf16 write); - # second pass recomputes norm and quantizes directly. Saves the qr_bf16 GM - # staging tensor and reduces kernel outputs to 2 (qr_scale, qr), avoiding the - # pypto multi-InOut OptimizeOrchTensors alias bug. + # Two passes per block: pass 1 computes amax; pass 2 recomputes norm and quantizes. 
for tg_idx in pl.spmd(T // T_TILE, name_hint="qr_rms_norm_quant"): tg = tg_idx * T_TILE qr_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) @@ -120,7 +107,6 @@ def qkv_proj_rope( qr_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(qr_sq_sum, 1.0 / Q_LORA), EPS))) qr_inv_rms_t = pl.reshape(qr_inv_rms, [T_TILE, 1]) - # Pass 1: norm + amax accumulation (no GM write of the normalized values). for qb in pl.range(Q_LORA // Q_LORA_TILE): qr_norm_col0 = qb * Q_LORA_TILE qr_norm_chunk = qr_fp32[tg : tg + T_TILE, qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE] @@ -139,7 +125,6 @@ def qkv_proj_rope( qr_tile_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T_TILE, 1]) qr_scale[tg : tg + T_TILE, :] = qr_tile_scale_dq - # Pass 2: recompute norm and quantize directly to qr. for qa in pl.range(0, Q_LORA, QUANT_TILE): qr_chunk = qr_fp32[tg : tg + T_TILE, qa : qa + QUANT_TILE] gamma_q_chunk = pl.reshape( @@ -154,15 +139,6 @@ def qkv_proj_rope( qr_q_half = pl.cast(qr_q_i32, target_type=pl.FP16, mode="round") qr[tg : tg + T_TILE, qa : qa + QUANT_TILE] = pl.cast(qr_q_half, target_type=pl.INT8, mode="trunc") - # Stage 3: W8A8C16 q_proj = qr_i8 @ wq_b, then dequantize to FP32. - # qproj_matmul is GROUP-chunked (Opt J); qproj_dequant is decoupled into its own - # outer pl.parallel with a larger DEQUANT_GROUP (Opt P), fed by a global INT32 - # staging buffer col_acc_all (16 MB at T=128). Decoupling lets dequant pick its - # own task size without forcing matmul to do the same — Opt J showed that - # matmul GRP=16 caused dispatcher contention upstream. - # `(hg + h_inner) * X` is inlined everywhere — binding it to a Python local - # inside pl.range causes pypto AST to thread it through pl.parallel's init_values, - # which fails SSA verification (see feedback_pypto_head_group_chunking_loop_carried.md). 
q_proj_fp32 = pl.create_tensor([T, H * HEAD_DIM], dtype=pl.FP32) for hg_idx in pl.spmd(((H * HEAD_DIM) // Q_PROJ_OUT_TILE) // 16, name_hint="qproj"): hg = hg_idx * 16 @@ -176,7 +152,8 @@ def qkv_proj_rope( col_acc = pl.matmul(qr_i8_chunk, wq_chunk, out_dtype=pl.INT32) else: col_acc = pl.matmul_acc(col_acc, qr_i8_chunk, wq_chunk) - w_scale = wq_b_scale[hg + h_inner : hg + h_inner + 1, :] + w_col0 = (hg + h_inner) * Q_PROJ_OUT_TILE + w_scale = pl.reshape(wq_b_scale[w_col0 : w_col0 + Q_PROJ_OUT_TILE], [1, Q_PROJ_OUT_TILE]) for tc in pl.range(0, T, T_TILE): col_acc_t = col_acc[tc : tc + T_TILE, :] col_fp32 = pl.cast(col_acc_t, target_type=pl.FP32, mode="none") @@ -184,24 +161,10 @@ def qkv_proj_rope( col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_dq_t), w_scale) q_proj_fp32[tc : tc + T_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] = col_dequant - # Stage 4: per-head RMSNorm + RoPE on q. - # Split into q_head_rms_nope and q_head_rope at T=128 — the fused - # [RMS+NOPE+RoPE] scope holds ~7 FP32 [T, ROPE_HALF|ROPE_DIM] tensors in the - # RoPE block and exceeds the 192KB Vec UB. inv_rms crosses the boundary via - # a [H, T] FP32 staging tensor. - # - # q_head_rms_nope stays at pl.parallel(0, H, 1) — fine-grained: 64 tasks - # saturate the 48 AIV cores, span is already optimal. HEAD_GROUP chunking - # was tried (Opt M) and reverted; see perf-tuning doc. - # - # q_head_rope/reassemble/write are HEAD_GROUP-chunked. q_head_rope writes a - # cross-head staging tensor q_rope_pair_stage [H*T, ROPE_DIM]; the ROPE_DIM - # trailing axis (not ROPE_HALF) is intentional — pypto's orch-tensor optimizer - # would otherwise alias it to the BF16 [T, ROPE_HALF] kv_rot_*_tmp temps later - # in the function, triggering a known pypto codegen bug. + # Per-head RMS, NOPE projection, and RoPE rotation, staged through + # q_head_inv_rms_all and q_rope_pair_stage. 
q_flat = pl.reshape(q, [T, H * HEAD_DIM]) q_head_inv_rms_all = pl.create_tensor([H, T], dtype=pl.FP32) - q_rope_pair_stage = pl.create_tensor([H * T, ROPE_DIM], dtype=pl.BF16) for hg_idx in pl.spmd(H // 2, name_hint="q_head_rms_nope"): hg = hg_idx * 2 for h_inner in pl.range(2): @@ -222,9 +185,7 @@ def qkv_proj_rope( q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) q_flat[:, h0 + n0 : h0 + n0 + HEAD_TILE] = pl.cast(q_normed, target_type=pl.BF16, mode="rint") - # q_head_rope HEAD_GROUP-chunked (Opt K). Only one cross-iter loop-carried - # tensor (q_rope_pair_stage), satisfying the success condition for chunked - # parallel scopes. + q_rope_pair_stage = pl.create_tensor([H * T, ROPE_DIM], dtype=pl.BF16) for hg_idx in pl.spmd(H // 8, name_hint="q_head_rope"): hg = hg_idx * 8 q_head_inv_rms_t = pl.create_tensor([T, 1], dtype=pl.FP32) @@ -243,7 +204,7 @@ def qkv_proj_rope( q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, :ROPE_HALF] = q_rot_even_bf16 q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : ROPE_DIM] = q_rot_odd_bf16 - # Stage 4d: reassemble (cube) + write (vec) — two SPMD scopes share a GM staging tensor. + # Reassemble interleaved pairs (cube matmul) and write back to q_flat (vec cast). q_rope_grp_fp32 = pl.create_tensor([H * T, ROPE_DIM], dtype=pl.FP32) for hg_idx in pl.spmd(H // 8, name_hint="q_rope_reassemble"): hg = hg_idx * 8 @@ -270,8 +231,6 @@ def qkv_proj_rope( q = pl.reshape(q_flat, [T, H, HEAD_DIM]) - # Stage 5/6: kv = rms_norm(token_x @ wkv, gamma_ckv) + RoPE. - # K loop uses pl.pipeline(stage=4) per Opt X (D_BLOCKS=32, enough iters). 
kv_fp32 = pl.create_tensor([T, HEAD_DIM], dtype=pl.FP32) for kbg in pl.spmd((HEAD_DIM // KV_TILE) // 2, name_hint="kv_proj_matmul"): for k_inner in pl.range(2): @@ -353,22 +312,22 @@ def qkv_proj_rope( @pl.jit def qkv_proj_rope_test( - x: pl.Tensor[[B, S, D], pl.BF16], - norm_w: pl.Tensor[[D], pl.FP32], - wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], - wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[(H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE], pl.FP32], - wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], - rope_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], - rope_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], + x: pl.Tensor[[B, S, D], pl.BF16], + norm_w: pl.Tensor[[D], pl.FP32], + wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], + wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], + wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], + rope_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], + rope_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], even_select_t: pl.Tensor[[ROPE_HALF, ROPE_DIM], pl.BF16], - odd_select_t: pl.Tensor[[ROPE_HALF, ROPE_DIM], pl.BF16], - gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], - gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], - q: pl.Out[pl.Tensor[[T, H, HEAD_DIM], pl.BF16]], - kv: pl.Out[pl.Tensor[[T, HEAD_DIM], pl.BF16]], - qr: pl.Out[pl.Tensor[[T, Q_LORA], pl.INT8]], - qr_scale: pl.Out[pl.Tensor[[T, 1], pl.FP32]], + odd_select_t: pl.Tensor[[ROPE_HALF, ROPE_DIM], pl.BF16], + gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], + gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], + q: pl.Out[pl.Tensor[[T, H, HEAD_DIM], pl.BF16]], + kv: pl.Out[pl.Tensor[[T, HEAD_DIM], pl.BF16]], + qr: pl.Out[pl.Tensor[[T, Q_LORA], pl.INT8]], + qr_scale: pl.Out[pl.Tensor[[T, 1], pl.FP32]], ): q = qkv_proj_rope( x, @@ -395,15 +354,15 @@ def golden_qkv_proj_rope(tensors): """Torch reference: attn_norm fused, then Q/KV LoRA + RoPE (model.py 692, 495-504).""" import torch - x = tensors["x"].float() # [B, S, D] - norm_w = tensors["norm_w"].float() # [D] - wq_a = 
tensors["wq_a"].float() - wq_b = tensors["wq_b"] + x = tensors["x"].float() + norm_w = tensors["norm_w"].float() + wq_a = tensors["wq_a"].float() + wq_b = tensors["wq_b"] wq_b_scale = tensors["wq_b_scale"].float().view(-1) - wkv = tensors["wkv"].float() - rope_cos = tensors["rope_cos"].float() - rope_sin = tensors["rope_sin"].float() - gamma_cq = tensors["gamma_cq"].float() + wkv = tensors["wkv"].float() + rope_cos = tensors["rope_cos"].float() + rope_sin = tensors["rope_sin"].float() + gamma_cq = tensors["gamma_cq"].float() gamma_ckv = tensors["gamma_ckv"].float() def int8_quant_per_row(x): @@ -512,14 +471,14 @@ def init_gamma_ckv(): wq_b_bf16 = init_wq_b().to(torch.bfloat16) wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16) - wq_b_scale = wq_b_scale.view((H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE) + wq_b_scale = wq_b_scale.view(H * HEAD_DIM) return [ TensorSpec("x", [B, S, D], torch.bfloat16, init_value=init_x), TensorSpec("norm_w", [D], torch.float32, init_value=init_norm_w), TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=init_wq_a), TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8), - TensorSpec("wq_b_scale", [(H * HEAD_DIM) // Q_PROJ_OUT_TILE, Q_PROJ_OUT_TILE], torch.float32, init_value=lambda: wq_b_scale), + TensorSpec("wq_b_scale", [H * HEAD_DIM], torch.float32, init_value=lambda: wq_b_scale), TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv), TensorSpec("rope_cos", [T, ROPE_DIM], torch.bfloat16, init_value=init_cos), TensorSpec("rope_sin", [T, ROPE_DIM], torch.bfloat16, init_value=init_sin), From 4b770033e4ad33ab6b5e4ba33b33a9e70074798e Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 26 May 2026 14:43:26 +0800 Subject: [PATCH 4/8] Refactor: tile-size knobs and pl.pipeline conversion in qkv_proj_rope - Halve T_TILE (16 -> 8) to double attn_norm and qr_rms_norm_quant SPMD block counts; introduce QPROJ_T_TILE=16 for the qproj dequant T loop (cube innerRows alignment) and 
KV_RMS_T_TILE=16 for the new kv_rms_norm SPMD scope - Convert kv_rms_norm from a single pl.at scope into an 8-way SPMD scope over T (per-token reduction) - Convert remaining pl.range loops inside SPMD / pl.at to pl.pipeline(stage=2); keep pl.range on q_head_rms_nope and q_head_rope outer h_inner loops where pipelining would blow Vec UB - Inline ROPE_PAIR_TILE as 32 (no longer derived) --- models/deepseek/v4/qkv_proj_rope.py | 61 +++++++++++++++-------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py index bfffde49..584e8084 100644 --- a/models/deepseek/v4/qkv_proj_rope.py +++ b/models/deepseek/v4/qkv_proj_rope.py @@ -30,7 +30,7 @@ # tiling ROPE_TILE = 64 -ROPE_PAIR_TILE = ROPE_TILE // 2 +ROPE_PAIR_TILE = 32 HEAD_TILE = 64 Q_PROJ_OUT_TILE = 128 Q_PROJ_TILE = 512 @@ -38,7 +38,9 @@ D_TILE = 128 KV_TILE = 32 QUANT_TILE = 32 -T_TILE = 16 +T_TILE = 8 +QPROJ_T_TILE = 16 +KV_RMS_T_TILE = 16 @pl.jit.inline @@ -66,13 +68,13 @@ def qkv_proj_rope( for tg_idx in pl.spmd(T // T_TILE, name_hint="attn_norm"): tg = tg_idx * T_TILE x_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) - for rms_db in pl.range(D // D_TILE): + for rms_db in pl.pipeline(D // D_TILE, stage=2): rms_d0 = rms_db * D_TILE rms_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, rms_d0 : rms_d0 + D_TILE], target_type=pl.FP32) x_sq_sum = pl.add(x_sq_sum, pl.reshape(pl.row_sum(pl.mul(rms_x_chunk, rms_x_chunk)), [1, T_TILE])) x_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(x_sq_sum, 1.0 / D), EPS))) x_inv_rms_t = pl.reshape(x_inv_rms, [T_TILE, 1]) - for apply_db in pl.range(D // D_TILE): + for apply_db in pl.pipeline(D // D_TILE, stage=2): apply_d0 = apply_db * D_TILE apply_x_chunk = pl.cast(x_flat[tg : tg + T_TILE, apply_d0 : apply_d0 + D_TILE], target_type=pl.FP32) norm_w_chunk = pl.reshape(norm_w[apply_d0 : apply_d0 + D_TILE], [1, D_TILE]) @@ -82,7 +84,7 @@ def qkv_proj_rope( qr_fp32 = pl.create_tensor([T, Q_LORA], 
dtype=pl.FP32) for qbg_idx in pl.spmd((Q_LORA // Q_LORA_TILE) // 2, name_hint="qr_proj_matmul"): qbg = qbg_idx * 2 - for q_inner in pl.range(2): + for q_inner in pl.pipeline(2, stage=2): q_a_col0 = (qbg + q_inner) * Q_LORA_TILE q_acc = pl.create_tensor([T, Q_LORA_TILE], dtype=pl.FP32) for db in pl.pipeline(0, D // D_TILE, stage=2): @@ -100,14 +102,14 @@ def qkv_proj_rope( tg = tg_idx * T_TILE qr_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) qr_tile_amax = pl.full([1, T_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) - for qr_rms_qb in pl.range(Q_LORA // Q_LORA_TILE): + for qr_rms_qb in pl.pipeline(Q_LORA // Q_LORA_TILE, stage=2): qr_rms_col0 = qr_rms_qb * Q_LORA_TILE qr_rms_chunk = qr_fp32[tg : tg + T_TILE, qr_rms_col0 : qr_rms_col0 + Q_LORA_TILE] qr_sq_sum = pl.add(qr_sq_sum, pl.reshape(pl.row_sum(pl.mul(qr_rms_chunk, qr_rms_chunk)), [1, T_TILE])) qr_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(qr_sq_sum, 1.0 / Q_LORA), EPS))) qr_inv_rms_t = pl.reshape(qr_inv_rms, [T_TILE, 1]) - for qb in pl.range(Q_LORA // Q_LORA_TILE): + for qb in pl.pipeline(Q_LORA // Q_LORA_TILE, stage=2): qr_norm_col0 = qb * Q_LORA_TILE qr_norm_chunk = qr_fp32[tg : tg + T_TILE, qr_norm_col0 : qr_norm_col0 + Q_LORA_TILE] gamma_chunk = pl.reshape( @@ -125,7 +127,7 @@ def qkv_proj_rope( qr_tile_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T_TILE, 1]) qr_scale[tg : tg + T_TILE, :] = qr_tile_scale_dq - for qa in pl.range(0, Q_LORA, QUANT_TILE): + for qa in pl.pipeline(0, Q_LORA, QUANT_TILE, stage=2): qr_chunk = qr_fp32[tg : tg + T_TILE, qa : qa + QUANT_TILE] gamma_q_chunk = pl.reshape( pl.cast(gamma_cq[qa : qa + QUANT_TILE], target_type=pl.FP32), @@ -143,7 +145,7 @@ def qkv_proj_rope( for hg_idx in pl.spmd(((H * HEAD_DIM) // Q_PROJ_OUT_TILE) // 16, name_hint="qproj"): hg = hg_idx * 16 col_acc = pl.create_tensor([T, Q_PROJ_OUT_TILE], dtype=pl.INT32) - for h_inner in pl.range(16): + for h_inner in pl.pipeline(16, stage=2): for qb in pl.pipeline(0, Q_LORA // Q_PROJ_TILE, stage=2): 
qr_proj_col0 = qb * Q_PROJ_TILE qr_i8_chunk = qr[:, qr_proj_col0 : qr_proj_col0 + Q_PROJ_TILE] @@ -154,12 +156,12 @@ def qkv_proj_rope( col_acc = pl.matmul_acc(col_acc, qr_i8_chunk, wq_chunk) w_col0 = (hg + h_inner) * Q_PROJ_OUT_TILE w_scale = pl.reshape(wq_b_scale[w_col0 : w_col0 + Q_PROJ_OUT_TILE], [1, Q_PROJ_OUT_TILE]) - for tc in pl.range(0, T, T_TILE): - col_acc_t = col_acc[tc : tc + T_TILE, :] + for tc in pl.pipeline(0, T, QPROJ_T_TILE, stage=2): + col_acc_t = col_acc[tc : tc + QPROJ_T_TILE, :] col_fp32 = pl.cast(col_acc_t, target_type=pl.FP32, mode="none") - qr_scale_dq_t = qr_scale[tc : tc + T_TILE, :] + qr_scale_dq_t = qr_scale[tc : tc + QPROJ_T_TILE, :] col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_dq_t), w_scale) - q_proj_fp32[tc : tc + T_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] = col_dequant + q_proj_fp32[tc : tc + QPROJ_T_TILE, (hg + h_inner) * Q_PROJ_OUT_TILE : (hg + h_inner) * Q_PROJ_OUT_TILE + Q_PROJ_OUT_TILE] = col_dequant # Per-head RMS, NOPE projection, and RoPE rotation, staged through # q_head_inv_rms_all and q_rope_pair_stage. 
@@ -171,7 +173,7 @@ def qkv_proj_rope( h = hg + h_inner h0 = h * HEAD_DIM q_head_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for db in pl.range(HEAD_DIM // HEAD_TILE): + for db in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2): d0 = h0 + db * HEAD_TILE q_head_chunk = q_proj_fp32[:, d0 : d0 + HEAD_TILE] q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T])) @@ -179,7 +181,7 @@ def qkv_proj_rope( q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1]) q_head_inv_rms_all[h : h + 1, :] = q_head_inv_rms - for nb in pl.range(NOPE_DIM // HEAD_TILE): + for nb in pl.pipeline(NOPE_DIM // HEAD_TILE, stage=2): n0 = nb * HEAD_TILE q_nope_chunk = q_proj_fp32[:, h0 + n0 : h0 + n0 + HEAD_TILE] q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) @@ -208,7 +210,7 @@ def qkv_proj_rope( q_rope_grp_fp32 = pl.create_tensor([H * T, ROPE_DIM], dtype=pl.FP32) for hg_idx in pl.spmd(H // 8, name_hint="q_rope_reassemble"): hg = hg_idx * 8 - for h_inner in pl.range(8): + for h_inner in pl.pipeline(8, stage=2): even_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, :ROPE_HALF] odd_chunk = q_rope_pair_stage[(hg + h_inner) * T : (hg + h_inner) * T + T, ROPE_HALF : ROPE_DIM] rot = pl.matmul( @@ -225,7 +227,7 @@ def qkv_proj_rope( for hg_idx in pl.spmd(H // 8, name_hint="q_rope_write"): hg = hg_idx * 8 - for h_inner in pl.range(8): + for h_inner in pl.pipeline(8, stage=2): rot_fp32 = q_rope_grp_fp32[(hg + h_inner) * T : (hg + h_inner) * T + T, :] q_flat[:, (hg + h_inner) * HEAD_DIM + NOPE_DIM : (hg + h_inner) * HEAD_DIM + NOPE_DIM + ROPE_DIM] = pl.cast(rot_fp32, target_type=pl.BF16, mode="rint") @@ -233,7 +235,7 @@ def qkv_proj_rope( kv_fp32 = pl.create_tensor([T, HEAD_DIM], dtype=pl.FP32) for kbg in pl.spmd((HEAD_DIM // KV_TILE) // 2, name_hint="kv_proj_matmul"): - for k_inner in pl.range(2): + for k_inner in pl.pipeline(2, stage=2): kv_acc = pl.create_tensor([T, KV_TILE], dtype=pl.FP32) kv_col0 = (kbg * 2 + 
k_inner) * KV_TILE for db in pl.pipeline(0, D // D_TILE, stage=2): @@ -247,25 +249,26 @@ def qkv_proj_rope( kv_fp32[:, kv_col0 : kv_col0 + KV_TILE] = kv_acc kv_inv_rms_tensor = pl.create_tensor([1, T], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rms_norm"): - kv_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for kb in pl.range(HEAD_DIM // KV_TILE): + for tg_idx in pl.spmd(T // KV_RMS_T_TILE, name_hint="kv_rms_norm"): + tg = tg_idx * KV_RMS_T_TILE + kv_sq_sum = pl.full([1, KV_RMS_T_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(HEAD_DIM // KV_TILE, stage=2): kv_sq_col0 = kb * KV_TILE - kv_chunk = kv_fp32[:, kv_sq_col0 : kv_sq_col0 + KV_TILE] - kv_sq_sum = pl.add(kv_sq_sum, pl.reshape(pl.row_sum(pl.mul(kv_chunk, kv_chunk)), [1, T])) + kv_chunk = kv_fp32[tg : tg + KV_RMS_T_TILE, kv_sq_col0 : kv_sq_col0 + KV_TILE] + kv_sq_sum = pl.add(kv_sq_sum, pl.reshape(pl.row_sum(pl.mul(kv_chunk, kv_chunk)), [1, KV_RMS_T_TILE])) kv_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(kv_sq_sum, 1.0 / HEAD_DIM), EPS))) - kv_inv_rms_t = pl.reshape(kv_inv_rms, [T, 1]) - kv_inv_rms_tensor[0:1, :] = kv_inv_rms + kv_inv_rms_t = pl.reshape(kv_inv_rms, [KV_RMS_T_TILE, 1]) + kv_inv_rms_tensor[0:1, tg : tg + KV_RMS_T_TILE] = kv_inv_rms - for nb in pl.range(NOPE_DIM // KV_TILE): + for nb in pl.pipeline(NOPE_DIM // KV_TILE, stage=2): n0 = nb * KV_TILE - kv_chunk = kv_fp32[:, n0 : n0 + KV_TILE] + kv_chunk = kv_fp32[tg : tg + KV_RMS_T_TILE, n0 : n0 + KV_TILE] gamma_kv_chunk = pl.reshape( pl.cast(gamma_ckv[n0 : n0 + KV_TILE], target_type=pl.FP32), [1, KV_TILE], ) kv_normed = pl.col_expand_mul(pl.row_expand_mul(kv_chunk, kv_inv_rms_t), gamma_kv_chunk) - kv[:, n0 : n0 + KV_TILE] = pl.cast(kv_normed, target_type=pl.BF16, mode="rint") + kv[tg : tg + KV_RMS_T_TILE, n0 : n0 + KV_TILE] = pl.cast(kv_normed, target_type=pl.BF16, mode="rint") kv_rot_even_tmp = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) kv_rot_odd_tmp = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) @@ 
-288,7 +291,7 @@ def qkv_proj_rope( kv_rope_full = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"): - for rope_col in pl.range(0, ROPE_DIM, ROPE_TILE): + for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_TILE, stage=2): pair_col = rope_col // 2 kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_TILE] kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_TILE] From 25ad2024b11e736a92245e95b12c7fdca90b7448 Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 26 May 2026 14:47:15 +0800 Subject: [PATCH 5/8] Refactor: flatten wq_b_scale to [H * HEAD_DIM] in dsv4 callers Match the qkv_proj_rope signature change: drop the [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK] 2D shape and the corresponding .view() at TensorSpec build time. Updated decode_attention_{csa,hca,swa} and decode_{csa,hca,swa}. Draft callers (*_draft.py) left untouched. --- models/deepseek/v4/decode_attention_csa.py | 7 +++---- models/deepseek/v4/decode_attention_hca.py | 7 +++---- models/deepseek/v4/decode_attention_swa.py | 7 +++---- models/deepseek/v4/decode_csa.py | 4 ++-- models/deepseek/v4/decode_hca.py | 4 ++-- models/deepseek/v4/decode_swa.py | 4 ++-- 6 files changed, 15 insertions(+), 18 deletions(-) diff --git a/models/deepseek/v4/decode_attention_csa.py b/models/deepseek/v4/decode_attention_csa.py index c6f19a85..635d4172 100644 --- a/models/deepseek/v4/decode_attention_csa.py +++ b/models/deepseek/v4/decode_attention_csa.py @@ -115,7 +115,7 @@ def attention_csa( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -365,7 +365,7 @@ def attention_csa_test_refresh( attn_norm_w: 
pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -969,7 +969,6 @@ def init_wo_b(): wq_b_bf16 = init_wq_b().to(torch.bfloat16) wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16) - wq_b_scale = wq_b_scale.view(Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK) wo_b_bf16 = init_wo_b().to(torch.bfloat16) wo_b_i8, wo_b_scale = quant_w_per_row(wo_b_bf16) @@ -981,7 +980,7 @@ def init_wo_b(): TensorSpec("attn_norm_w", [D], torch.float32, init_value=lambda: shared_attn_norm_w.clone()), TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=lambda: shared_wq_a.clone()), TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8), - TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32, init_value=lambda: wq_b_scale), + TensorSpec("wq_b_scale", [H * HEAD_DIM], torch.float32, init_value=lambda: wq_b_scale), TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv), TensorSpec("gamma_cq", [Q_LORA], torch.bfloat16, init_value=lambda: shared_gamma_cq.clone()), TensorSpec("gamma_ckv", [HEAD_DIM], torch.bfloat16, init_value=init_gamma_ckv), diff --git a/models/deepseek/v4/decode_attention_hca.py b/models/deepseek/v4/decode_attention_hca.py index 4541fe7a..81563815 100644 --- a/models/deepseek/v4/decode_attention_hca.py +++ b/models/deepseek/v4/decode_attention_hca.py @@ -84,7 +84,7 @@ def attention_hca( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: 
pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -326,7 +326,7 @@ def attention_hca_test( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -664,7 +664,6 @@ def init_wo_b(): wq_b_bf16 = init_wq_b().to(torch.bfloat16) wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16) - wq_b_scale = wq_b_scale.view(Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK) wo_b_bf16 = init_wo_b().to(torch.bfloat16) wo_b_i8, wo_b_scale = quant_w_per_row(wo_b_bf16) @@ -676,7 +675,7 @@ def init_wo_b(): TensorSpec("attn_norm_w", [D], torch.float32, init_value=init_attn_norm_w), TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=init_wq_a), TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8), - TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32, init_value=lambda: wq_b_scale), + TensorSpec("wq_b_scale", [H * HEAD_DIM], torch.float32, init_value=lambda: wq_b_scale), TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv), TensorSpec("gamma_cq", [Q_LORA], torch.bfloat16, init_value=init_gamma_cq), TensorSpec("gamma_ckv", [HEAD_DIM], torch.bfloat16, init_value=init_gamma_ckv), diff --git a/models/deepseek/v4/decode_attention_swa.py b/models/deepseek/v4/decode_attention_swa.py index b915d2fa..c9597381 100644 --- a/models/deepseek/v4/decode_attention_swa.py +++ b/models/deepseek/v4/decode_attention_swa.py @@ -77,7 +77,7 @@ def attention_swa( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], 
pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -234,7 +234,7 @@ def attention_swa_test( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -484,7 +484,6 @@ def init_wo_b(): wq_b_bf16 = init_wq_b().to(torch.bfloat16) wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16) - wq_b_scale = wq_b_scale.view(Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK) wo_b_bf16 = init_wo_b().to(torch.bfloat16) wo_b_i8, wo_b_scale = quant_w_per_row(wo_b_bf16) @@ -496,7 +495,7 @@ def init_wo_b(): TensorSpec("attn_norm_w", [D], torch.float32, init_value=init_attn_norm_w), TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=init_wq_a), TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8), - TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32, init_value=lambda: wq_b_scale), + TensorSpec("wq_b_scale", [H * HEAD_DIM], torch.float32, init_value=lambda: wq_b_scale), TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv), TensorSpec("gamma_cq", [Q_LORA], torch.bfloat16, init_value=init_gamma_cq), TensorSpec("gamma_ckv", [HEAD_DIM], torch.bfloat16, init_value=init_gamma_ckv), diff --git a/models/deepseek/v4/decode_csa.py b/models/deepseek/v4/decode_csa.py index 7372fea8..e7184770 100644 --- a/models/deepseek/v4/decode_csa.py +++ b/models/deepseek/v4/decode_csa.py @@ -98,7 +98,7 @@ def decode_csa( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: 
pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -219,7 +219,7 @@ def decode_csa_test( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], diff --git a/models/deepseek/v4/decode_hca.py b/models/deepseek/v4/decode_hca.py index 2a478d00..accd357b 100644 --- a/models/deepseek/v4/decode_hca.py +++ b/models/deepseek/v4/decode_hca.py @@ -87,7 +87,7 @@ def decode_hca( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -191,7 +191,7 @@ def decode_hca_test( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], diff --git a/models/deepseek/v4/decode_swa.py b/models/deepseek/v4/decode_swa.py index 817d04f4..fac23d8f 100644 --- a/models/deepseek/v4/decode_swa.py +++ b/models/deepseek/v4/decode_swa.py @@ -74,7 +74,7 @@ def decode_swa( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, 
Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], @@ -163,7 +163,7 @@ def decode_swa_test( attn_norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], gamma_cq: pl.Tensor[[Q_LORA], pl.BF16], gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16], From dc2e1e0fdf40c429367a5569941ac81ad77e9717 Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 26 May 2026 14:59:23 +0800 Subject: [PATCH 6/8] Refactor: rename qkv_proj_rope.py to decode_qkv_proj_rope.py Disambiguates the decode-side kernel from the new prefill_qkv_proj_rope module. Updates all `from qkv_proj_rope import ...` callers (3 decode attention files, 1 prefill draft, and prefill_qkv_proj_rope itself). Function names (qkv_proj_rope, golden_qkv_proj_rope) are unchanged. 
--- models/deepseek/v4/decode_attention_csa.py | 4 ++-- models/deepseek/v4/decode_attention_hca.py | 4 ++-- models/deepseek/v4/decode_attention_swa.py | 4 ++-- .../deepseek/v4/{qkv_proj_rope.py => decode_qkv_proj_rope.py} | 0 models/deepseek/v4/prefill_attention_swa_draft.py | 2 +- models/deepseek/v4/prefill_qkv_proj_rope.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) rename models/deepseek/v4/{qkv_proj_rope.py => decode_qkv_proj_rope.py} (100%) diff --git a/models/deepseek/v4/decode_attention_csa.py b/models/deepseek/v4/decode_attention_csa.py index 635d4172..73ac9ab9 100644 --- a/models/deepseek/v4/decode_attention_csa.py +++ b/models/deepseek/v4/decode_attention_csa.py @@ -34,7 +34,7 @@ from hc_post import hc_post from hc_pre import hc_pre from decode_indexer import indexer -from qkv_proj_rope import qkv_proj_rope +from decode_qkv_proj_rope import qkv_proj_rope from decode_sparse_attn import sparse_attn B = DECODE_BATCH @@ -465,7 +465,7 @@ def golden_attention_csa(tensors): from decode_compressor_ratio4 import golden_compressor from hc_pre import golden_hc_pre from decode_indexer import golden_indexer - from qkv_proj_rope import golden_qkv_proj_rope + from decode_qkv_proj_rope import golden_qkv_proj_rope def rms_norm(x, weight): x_fp32 = x.float() diff --git a/models/deepseek/v4/decode_attention_hca.py b/models/deepseek/v4/decode_attention_hca.py index 81563815..ae0923f5 100644 --- a/models/deepseek/v4/decode_attention_hca.py +++ b/models/deepseek/v4/decode_attention_hca.py @@ -19,7 +19,7 @@ from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS from hc_pre import hc_pre from hc_post import hc_post -from qkv_proj_rope import qkv_proj_rope +from decode_qkv_proj_rope import qkv_proj_rope from decode_compressor_ratio128 import compressor from decode_sparse_attn import sparse_attn @@ -383,7 +383,7 @@ def golden_attention_hca(tensors): import torch from hc_pre import golden_hc_pre - from qkv_proj_rope import 
golden_qkv_proj_rope + from decode_qkv_proj_rope import golden_qkv_proj_rope from decode_compressor_ratio128 import golden_compressor from decode_sparse_attn import golden_sparse_attn from hc_post import golden_hc_post diff --git a/models/deepseek/v4/decode_attention_swa.py b/models/deepseek/v4/decode_attention_swa.py index c9597381..b21ee615 100644 --- a/models/deepseek/v4/decode_attention_swa.py +++ b/models/deepseek/v4/decode_attention_swa.py @@ -20,7 +20,7 @@ from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS from hc_pre import hc_pre from hc_post import hc_post -from qkv_proj_rope import qkv_proj_rope +from decode_qkv_proj_rope import qkv_proj_rope from decode_sparse_attn import sparse_attn @@ -280,7 +280,7 @@ def golden_attention_swa(tensors): import torch from hc_pre import golden_hc_pre - from qkv_proj_rope import golden_qkv_proj_rope + from decode_qkv_proj_rope import golden_qkv_proj_rope from decode_sparse_attn import golden_sparse_attn from hc_post import golden_hc_post diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/decode_qkv_proj_rope.py similarity index 100% rename from models/deepseek/v4/qkv_proj_rope.py rename to models/deepseek/v4/decode_qkv_proj_rope.py diff --git a/models/deepseek/v4/prefill_attention_swa_draft.py b/models/deepseek/v4/prefill_attention_swa_draft.py index 70f63fe9..03935ed1 100644 --- a/models/deepseek/v4/prefill_attention_swa_draft.py +++ b/models/deepseek/v4/prefill_attention_swa_draft.py @@ -61,7 +61,7 @@ def golden_prefill_attention_swa(tensors): import torch from hc_pre import golden_hc_pre - from qkv_proj_rope import golden_qkv_proj_rope + from decode_qkv_proj_rope import golden_qkv_proj_rope from hc_post import golden_hc_post from prefill_sparse_attn_draft import golden_prefill_sparse_attn diff --git a/models/deepseek/v4/prefill_qkv_proj_rope.py b/models/deepseek/v4/prefill_qkv_proj_rope.py index 4bea98ea..b3090e33 100644 --- 
a/models/deepseek/v4/prefill_qkv_proj_rope.py +++ b/models/deepseek/v4/prefill_qkv_proj_rope.py @@ -15,7 +15,7 @@ import pypto.language as pl from config import FLASH as M, INT8_AMAX_EPS, INT8_SCALE_MAX, PREFILL_BATCH, PREFILL_SEQ -from qkv_proj_rope import build_tensor_specs as _build_qkv_tensor_specs +from decode_qkv_proj_rope import build_tensor_specs as _build_qkv_tensor_specs B = PREFILL_BATCH From 6ba8231effaa5d893412a0ea901a4d863f28e33f Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 26 May 2026 15:10:19 +0800 Subject: [PATCH 7/8] Fix: align prefill_qkv_proj_rope wq_b_scale shape with decode The decode build_tensor_specs now produces wq_b_scale as 1D [H * HEAD_DIM]; prefill imports those specs but kept the old 2D [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK] signature, which made shapes disagree at jit time. Flatten the prefill signature too and slice + reshape at the dequant use site. --- models/deepseek/v4/prefill_qkv_proj_rope.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/models/deepseek/v4/prefill_qkv_proj_rope.py b/models/deepseek/v4/prefill_qkv_proj_rope.py index b3090e33..8c7c8507 100644 --- a/models/deepseek/v4/prefill_qkv_proj_rope.py +++ b/models/deepseek/v4/prefill_qkv_proj_rope.py @@ -89,7 +89,7 @@ def prefill_qkv_proj_rope_core( norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], freqs_cos: pl.Tensor[[MAX_SEQ_LEN, ROPE_DIM], pl.BF16], freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_DIM], pl.BF16], @@ -393,7 +393,8 @@ def prefill_qkv_proj_rope_core( :, ] col_fp32 = pl.cast(col_acc_chunk, target_type=pl.FP32, mode="none") - w_scale = wq_b_scale[hbg + h_inner : hbg + h_inner + 1, :] + w_col0 = (hbg + h_inner) * Q_PROJ_OUT_CHUNK + w_scale = pl.reshape(wq_b_scale[w_col0 : w_col0 + Q_PROJ_OUT_CHUNK], 
[1, Q_PROJ_OUT_CHUNK]) col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_tile), w_scale) q_proj_fp32[ :, @@ -504,7 +505,7 @@ def prefill_qkv_proj_rope( norm_w: pl.Tensor[[D], pl.FP32], wq_a: pl.Tensor[[D, Q_LORA], pl.BF16], wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8], - wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32], + wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32], wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16], freqs_cos: pl.Tensor[[MAX_SEQ_LEN, ROPE_DIM], pl.BF16], freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_DIM], pl.BF16], From 34d521548d6fe2ccf55a3b3e4ab13079a8820d74 Mon Sep 17 00:00:00 2001 From: zhangqi-chen Date: Tue, 26 May 2026 15:13:42 +0800 Subject: [PATCH 8/8] Chore: remove prefill_sparse_attn_draft and redirect importer Drop the obsolete _draft file. prefill_attention_swa_draft now imports golden_prefill_sparse_attn from the non-draft prefill_sparse_attn module. --- .../v4/prefill_attention_swa_draft.py | 2 +- .../deepseek/v4/prefill_sparse_attn_draft.py | 105 ------------------ 2 files changed, 1 insertion(+), 106 deletions(-) delete mode 100644 models/deepseek/v4/prefill_sparse_attn_draft.py diff --git a/models/deepseek/v4/prefill_attention_swa_draft.py b/models/deepseek/v4/prefill_attention_swa_draft.py index 03935ed1..bad4a9c2 100644 --- a/models/deepseek/v4/prefill_attention_swa_draft.py +++ b/models/deepseek/v4/prefill_attention_swa_draft.py @@ -63,7 +63,7 @@ def golden_prefill_attention_swa(tensors): from hc_pre import golden_hc_pre from decode_qkv_proj_rope import golden_qkv_proj_rope from hc_post import golden_hc_post - from prefill_sparse_attn_draft import golden_prefill_sparse_attn + from prefill_sparse_attn import golden_prefill_sparse_attn x_mixed = torch.zeros(B, S, D, dtype=torch.bfloat16) post_t = torch.zeros(B, S, HC_MULT) diff --git a/models/deepseek/v4/prefill_sparse_attn_draft.py b/models/deepseek/v4/prefill_sparse_attn_draft.py deleted file mode 100644 index 93ba7dea..00000000 --- 
a/models/deepseek/v4/prefill_sparse_attn_draft.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) PyPTO Contributors. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ----------------------------------------------------------------------------------------------------------- -# ruff: noqa: F401,F403,F405,F821 -"""DeepSeek-V4 prefill sparse_attn scaffold. - -Kernel body is intentionally empty; golden follows the torch reference for this stage. -""" - -import pypto.language as pl - -from decode_sparse_attn import * # noqa: F401,F403 -from decode_sparse_attn import build_tensor_specs as _build_tensor_specs -from decode_sparse_attn import _int8_quant_per_row - - -@pl.jit -def prefill_sparse_attn( - q: pl.Tensor[[T, H, HEAD_DIM], pl.BF16], - ori_kv: pl.Tensor[[ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], - ori_block_table: pl.Tensor[[B, ORI_MAX_BLOCKS], pl.INT32], - cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], - cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], - cmp_sparse_indices: pl.Tensor[[T, TOPK], pl.INT32], - attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B], pl.INT32], - freqs_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], - freqs_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], - even_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], - odd_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], - wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], - 
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], - wo_b_scale: pl.Tensor[[D], pl.FP32], - attn_out: pl.Out[pl.Tensor[[T, D], pl.BF16]], -): - # TODO: kernel implementation - return attn_out - - -def golden_prefill_sparse_attn(tensors): - """Non-distributed prefill sparse attention. - - Official prefill uses the full prompt KV cache with causal masking. This - golden mirrors that single-rank behavior with the existing standalone - sparse-attn tensor contract: ``ori_kv`` is treated as the full prompt cache - for each batch, and each query token attends to positions ``[0, s]``. - """ - import torch - - q = tensors["q"].float() - ori_kv = tensors["ori_kv"].float() - ori_block_table = tensors["ori_block_table"] - attn_sink = tensors["attn_sink"].float() - cos = tensors["freqs_cos"].float() - sin = tensors["freqs_sin"].float() - wo_a = tensors["wo_a"].float() - wo_b_i8 = tensors["wo_b"] - wo_b_scale = tensors["wo_b_scale"].float() - - o = torch.zeros(T, H, HEAD_DIM) - for t in range(T): - b = t // S - s = t % S - gathered = [] - for raw in range(s + 1): - blk_id = int(ori_block_table[b, raw // BLOCK_SIZE].item()) - intra = raw % BLOCK_SIZE - gathered.append(ori_kv[blk_id, intra, 0]) - kv_b = torch.stack(gathered, dim=0) - scores = (q[t] @ kv_b.T) * SOFTMAX_SCALE - score_max = scores.max(dim=-1, keepdim=True).values - exp_scores = torch.exp(scores - score_max) - oi_num = exp_scores @ kv_b - li = exp_scores.sum(dim=-1, keepdim=True) - denom = li + torch.exp(attn_sink.unsqueeze(-1) - score_max) - o[t] = oi_num / denom - - rope_pair = o[..., NOPE_DIM:].unflatten(-1, (-1, 2)) - rope_even = rope_pair[..., 0] - rope_odd = rope_pair[..., 1] - cos_half = cos[:, :HALF_ROPE].unsqueeze(1) - sin_half = sin[:, :HALF_ROPE].unsqueeze(1) - inv_even = (rope_even * cos_half + rope_odd * sin_half).to(torch.bfloat16).float() - inv_odd = (rope_odd * cos_half - rope_even * sin_half).to(torch.bfloat16).float() - o_rope = torch.stack([inv_even, inv_odd], dim=-1).flatten(-2) - o = 
torch.cat([o[..., :NOPE_DIM], o_rope], dim=-1).to(torch.bfloat16) - - o_model = o.float().view(B, S, O_GROUPS, O_GROUP_IN) - o_r = torch.einsum("bsgd,grd->bsgr", o_model, wo_a) - o_r = o_r.to(torch.bfloat16).float() - o_r_q = o_r.flatten(2).view(T, O_GROUPS * O_LORA) - o_r_i8, o_r_scale = _int8_quant_per_row(o_r_q) - acc = o_r_i8.to(torch.int32) @ wo_b_i8.to(torch.int32).T - out = acc.float() * o_r_scale * wo_b_scale.unsqueeze(0) - - tensors["attn_out"][:] = out.to(torch.bfloat16) - - -def build_tensor_specs(*args, **kwargs): - return _build_tensor_specs(*args, **kwargs)