Skip to content
Merged
11 changes: 5 additions & 6 deletions models/deepseek/v4/decode_attention_csa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -34,7 +34,7 @@
from hc_post import hc_post
from hc_pre import hc_pre
from decode_indexer import indexer
from qkv_proj_rope import qkv_proj_rope
from decode_qkv_proj_rope import qkv_proj_rope
from decode_sparse_attn import sparse_attn

B = DECODE_BATCH
Expand Down Expand Up @@ -115,7 +115,7 @@
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -365,7 +365,7 @@
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -465,7 +465,7 @@
from decode_compressor_ratio4 import golden_compressor
from hc_pre import golden_hc_pre
from decode_indexer import golden_indexer
from qkv_proj_rope import golden_qkv_proj_rope
from decode_qkv_proj_rope import golden_qkv_proj_rope

def rms_norm(x, weight):
x_fp32 = x.float()
Expand Down Expand Up @@ -969,7 +969,6 @@

wq_b_bf16 = init_wq_b().to(torch.bfloat16)
wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16)
wq_b_scale = wq_b_scale.view(Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK)
wo_b_bf16 = init_wo_b().to(torch.bfloat16)
wo_b_i8, wo_b_scale = quant_w_per_row(wo_b_bf16)

Expand All @@ -981,7 +980,7 @@
TensorSpec("attn_norm_w", [D], torch.float32, init_value=lambda: shared_attn_norm_w.clone()),
TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=lambda: shared_wq_a.clone()),
TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8),
TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32, init_value=lambda: wq_b_scale),
TensorSpec("wq_b_scale", [H * HEAD_DIM], torch.float32, init_value=lambda: wq_b_scale),
TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv),
TensorSpec("gamma_cq", [Q_LORA], torch.bfloat16, init_value=lambda: shared_gamma_cq.clone()),
TensorSpec("gamma_ckv", [HEAD_DIM], torch.bfloat16, init_value=init_gamma_ckv),
Expand Down
11 changes: 5 additions & 6 deletions models/deepseek/v4/decode_attention_hca.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS
from hc_pre import hc_pre
from hc_post import hc_post
from qkv_proj_rope import qkv_proj_rope
from decode_qkv_proj_rope import qkv_proj_rope
from decode_compressor_ratio128 import compressor
from decode_sparse_attn import sparse_attn

Expand Down Expand Up @@ -84,7 +84,7 @@ def attention_hca(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -326,7 +326,7 @@ def attention_hca_test(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -383,7 +383,7 @@ def golden_attention_hca(tensors):
import torch

from hc_pre import golden_hc_pre
from qkv_proj_rope import golden_qkv_proj_rope
from decode_qkv_proj_rope import golden_qkv_proj_rope
from decode_compressor_ratio128 import golden_compressor
from decode_sparse_attn import golden_sparse_attn
from hc_post import golden_hc_post
Expand Down Expand Up @@ -664,7 +664,6 @@ def init_wo_b():

wq_b_bf16 = init_wq_b().to(torch.bfloat16)
wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16)
wq_b_scale = wq_b_scale.view(Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK)
wo_b_bf16 = init_wo_b().to(torch.bfloat16)
wo_b_i8, wo_b_scale = quant_w_per_row(wo_b_bf16)

Expand All @@ -676,7 +675,7 @@ def init_wo_b():
TensorSpec("attn_norm_w", [D], torch.float32, init_value=init_attn_norm_w),
TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=init_wq_a),
TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8),
TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32, init_value=lambda: wq_b_scale),
TensorSpec("wq_b_scale", [H * HEAD_DIM], torch.float32, init_value=lambda: wq_b_scale),
TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv),
TensorSpec("gamma_cq", [Q_LORA], torch.bfloat16, init_value=init_gamma_cq),
TensorSpec("gamma_ckv", [HEAD_DIM], torch.bfloat16, init_value=init_gamma_ckv),
Expand Down
11 changes: 5 additions & 6 deletions models/deepseek/v4/decode_attention_swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS
from hc_pre import hc_pre
from hc_post import hc_post
from qkv_proj_rope import qkv_proj_rope
from decode_qkv_proj_rope import qkv_proj_rope
from decode_sparse_attn import sparse_attn


Expand Down Expand Up @@ -77,7 +77,7 @@ def attention_swa(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -234,7 +234,7 @@ def attention_swa_test(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -280,7 +280,7 @@ def golden_attention_swa(tensors):
import torch

from hc_pre import golden_hc_pre
from qkv_proj_rope import golden_qkv_proj_rope
from decode_qkv_proj_rope import golden_qkv_proj_rope
from decode_sparse_attn import golden_sparse_attn
from hc_post import golden_hc_post

Expand Down Expand Up @@ -484,7 +484,6 @@ def init_wo_b():

wq_b_bf16 = init_wq_b().to(torch.bfloat16)
wq_b_i8, wq_b_scale = quant_w_per_output_channel(wq_b_bf16)
wq_b_scale = wq_b_scale.view(Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK)
wo_b_bf16 = init_wo_b().to(torch.bfloat16)
wo_b_i8, wo_b_scale = quant_w_per_row(wo_b_bf16)

Expand All @@ -496,7 +495,7 @@ def init_wo_b():
TensorSpec("attn_norm_w", [D], torch.float32, init_value=init_attn_norm_w),
TensorSpec("wq_a", [D, Q_LORA], torch.bfloat16, init_value=init_wq_a),
TensorSpec("wq_b", [Q_LORA, H * HEAD_DIM], torch.int8, init_value=lambda: wq_b_i8),
TensorSpec("wq_b_scale", [Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], torch.float32, init_value=lambda: wq_b_scale),
TensorSpec("wq_b_scale", [H * HEAD_DIM], torch.float32, init_value=lambda: wq_b_scale),
TensorSpec("wkv", [D, HEAD_DIM], torch.bfloat16, init_value=init_wkv),
TensorSpec("gamma_cq", [Q_LORA], torch.bfloat16, init_value=init_gamma_cq),
TensorSpec("gamma_ckv", [HEAD_DIM], torch.bfloat16, init_value=init_gamma_ckv),
Expand Down
4 changes: 2 additions & 2 deletions models/deepseek/v4/decode_csa.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def decode_csa(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -219,7 +219,7 @@ def decode_csa_test(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down
4 changes: 2 additions & 2 deletions models/deepseek/v4/decode_hca.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def decode_hca(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down Expand Up @@ -191,7 +191,7 @@ def decode_hca_test(
attn_norm_w: pl.Tensor[[D], pl.FP32],
wq_a: pl.Tensor[[D, Q_LORA], pl.BF16],
wq_b: pl.Tensor[[Q_LORA, H * HEAD_DIM], pl.INT8],
wq_b_scale: pl.Tensor[[Q_PROJ_HEAD_BLOCKS, Q_PROJ_OUT_CHUNK], pl.FP32],
wq_b_scale: pl.Tensor[[H * HEAD_DIM], pl.FP32],
wkv: pl.Tensor[[D, HEAD_DIM], pl.BF16],
gamma_cq: pl.Tensor[[Q_LORA], pl.BF16],
gamma_ckv: pl.Tensor[[HEAD_DIM], pl.BF16],
Expand Down
Loading
Loading