Skip to content

Commit 2d0d276

Browse files
[PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes (NVIDIA#2537)
* Plumbing correct bias dims from TE to cudnn Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make changes for cp bias code Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add dbias and dbias_ to run_dpa_with_cp test Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix: Use output_dBias instead of input_dBias to extract the shape Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add guards for bias/bias_/dbias/dbias_ being None Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add support for bias shape 111s in addition to the original 1hss, 11ss, b1ss and bhss Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add support for dbias calculation and variant packing for the dbias shapes b1ss, bhss, 11ss in addition to the already supported 1hss Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add support for 111s bias shape in DPA Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Allow fused attn for dbias calculation for 11ss, b1ss, bhss. Disable fused attn if dbias calculation for 111s is required, else enable Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Disable requires_grad for bias for shape 111s in tests Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Disable bias grad / training flag for 111s bias in the non-CP attn tests. Add bias shape 111s to test_dpa_bias_shapes Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Fix to correctly create the bias shape tensor instead of the hard coded shape. 
Fix the comparison logic shapes for bias/dbias Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add fused attn cp test cases for all supported bias shapes Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit: switch to elif for bias grad conditional Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add CP support for bias/dbias shape 111s Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add support for is_training in CP attn tests Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit: Fix incorrect comment Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * nit: Fix incorrect comment and assert string Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Create the dbias graph tensor only if it is a cuDNN supported bias shape Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Fix the dim that is being compared for the two cp chunks in the test Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * nit: Reinstate the original test for right side swa Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> --------- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f122b07 commit 2d0d276

File tree

10 files changed

+501
-233
lines changed

10 files changed

+501
-233
lines changed

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 288 additions & 125 deletions
Large diffs are not rendered by default.

tests/pytorch/attention/test_attention.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,16 @@ def test_dot_product_attention(
162162
)
163163

164164
# Get backends
165+
# For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s.
166+
# For all other shapes test fwd+bwd
165167
is_training = True
168+
# TODO(KshitijLakhani): Set is_training to True for all cases once cuDNN supports dbias for 111s.
169+
if config.bias_shape == "111s":
170+
is_training = False
171+
logging.info(
172+
"Setting is_training to False as cuDNN does not support dbias for"
173+
f" {config.bias_shape=} "
174+
)
166175
available_backends, _, fused_attn_backends = get_available_attention_backends(
167176
config,
168177
qkv_dtype=dtype,
@@ -636,7 +645,8 @@ def test_dpa_bias(dtype, model_configs, model):
636645
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
637646
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
638647
"bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
639-
"bias_1_4": ModelConfig(
648+
"bias_1_4": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="111s"),
649+
"bias_1_5": ModelConfig(
640650
4,
641651
2048,
642652
24,
@@ -646,7 +656,7 @@ def test_dpa_bias(dtype, model_configs, model):
646656
bias_shape="1hss",
647657
alibi_type="custom",
648658
),
649-
"bias_1_5": ModelConfig(
659+
"bias_1_6": ModelConfig(
650660
2,
651661
2048,
652662
24,
@@ -1143,10 +1153,16 @@ def _run_dot_product_attention(
11431153
bias = None
11441154
if config.attn_bias_type == "post_scale_bias":
11451155
shape = "_".join(config.bias_shape)
1156+
# For 1hss, 11ss, b1ss, bhss
1157+
shape_cache = shape
11461158
shape = shape.replace("_s_s", "_sq_skv")
1159+
# For 111s
1160+
if shape == shape_cache:
1161+
shape = shape.replace("_1_s", "_1_skv")
11471162
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
11481163
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
1149-
if config.bias_shape != "1hss":
1164+
# For 111s, dbias calculation is not supported as of cuDNN 9.18
1165+
if config.bias_shape == "111s":
11501166
bias.requires_grad = False
11511167

11521168
# Create RNG

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
147147
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
148148
), # MHA
149149
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
150-
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
150+
"cp_1_4": ModelConfig(
151+
2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"
152+
), # MHA
153+
"cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
151154
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
152155
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
153156
"cp_2_2": ModelConfig(
@@ -160,9 +163,30 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
160163
attn_bias_type="post_scale_bias",
161164
), # GQA
162165
"cp_2_3": ModelConfig(
163-
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
166+
2,
167+
4096,
168+
12,
169+
128,
170+
num_gqa_groups=2,
171+
attn_mask_type="causal",
172+
attn_bias_type="post_scale_bias",
173+
bias_shape="11ss",
164174
), # GQA
165175
"cp_2_4": ModelConfig(
176+
2,
177+
4096,
178+
12,
179+
128,
180+
num_gqa_groups=2,
181+
attn_mask_type="causal",
182+
attn_bias_type="post_scale_bias",
183+
bias_shape="111s",
184+
return_max_logit=True,
185+
), # GQA
186+
"cp_2_5": ModelConfig(
187+
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
188+
), # GQA
189+
"cp_2_6": ModelConfig(
166190
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
167191
), # GQA
168192
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
@@ -171,6 +195,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
171195
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
172196
), # MLA
173197
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
198+
"cp_3_4": ModelConfig(
199+
2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss", head_dim_v=64
200+
), # MLA
174201
"cp_4_0": ModelConfig(
175202
2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
176203
), # GQA
@@ -191,10 +218,13 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
191218
"cp_1_0",
192219
"cp_1_1",
193220
"cp_1_4",
221+
"cp_1_5",
194222
"cp_2_0",
195223
"cp_2_2",
224+
"cp_2_3",
196225
"cp_2_4",
197226
"cp_3_2",
227+
"cp_3_4",
198228
"cp_4_2",
199229
]
200230
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
@@ -324,12 +354,15 @@ def test_cp_with_fused_attention(
324354
Float8CurrentScaling(fp8_dpa=True),
325355
DelayedScaling(fp8_dpa=True),
326356
]
357+
# For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s.
358+
is_training = False if config.bias_shape == "111s" else True
327359
available_backends, _, fused_attn_backends = get_available_attention_backends(
328360
config,
329361
qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
330362
qkv_layout="_".join([qkv_format] * 3),
331363
fp8=fp8,
332364
fp8_meta=fp8_meta,
365+
is_training=is_training,
333366
)
334367
_, fused_attn_supported, _ = available_backends
335368
if not fused_attn_supported:
@@ -348,6 +381,7 @@ def test_cp_with_fused_attention(
348381
fp8_mha=fp8_mha,
349382
scaling_mode=scaling_mode,
350383
f16_O=f16_O,
384+
is_training=is_training,
351385
log_level=pytest_logging_level,
352386
),
353387
check=True,

tests/pytorch/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ def get_available_attention_backends(
271271
os.environ["NVTE_FUSED_ATTN"] = "1"
272272
os.environ["NVTE_UNFUSED_ATTN"] = "1"
273273
_attention_backends["backend_selection_requires_update"] = True
274-
275274
alibi_slopes_shape = None
276275
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
277276
if config.bias_shape == "1hss":
@@ -289,7 +288,9 @@ def get_available_attention_backends(
289288
and config.head_dim_qk <= 128
290289
and config.head_dim_v <= 128
291290
):
292-
core_attention_bias_requires_grad = True
291+
# TODO(KshitijLakhani): Remove this guard once cuDNN starts supporting dbias calculation for bias shape 111s
292+
if core_attention_bias_shape != "111s":
293+
core_attention_bias_requires_grad = True
293294

294295
fused_attn_backends = []
295296
available_backends = None

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
5252
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
5353
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
5454
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
55-
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
56-
bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
57-
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
58-
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ,
59-
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1,
60-
void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
61-
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
55+
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv,
56+
bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability,
57+
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
58+
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
59+
bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
60+
void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO,
61+
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
62+
void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
6263
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
6364
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
6465
using namespace transformer_engine;
@@ -121,6 +122,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
121122
max_pages_per_seq_v,
122123
bias_b,
123124
bias_h,
125+
bias_sq,
126+
bias_skv,
124127
scaling_factor,
125128
is_training,
126129
dropout_probability,
@@ -269,10 +272,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
269272
sdpa_options.set_alibi_mask(is_alibi);
270273

271274
if (is_bias) {
272-
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
273-
.set_name("bias")
274-
.set_dim({bias_b, bias_h, s_q, s_kv})
275-
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
275+
bias = mha_graph->tensor(
276+
fe::graph::Tensor_attributes()
277+
.set_name("bias")
278+
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
279+
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
276280
sdpa_options.set_bias(bias);
277281
}
278282

@@ -548,16 +552,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
548552
void fused_attn_arbitrary_seqlen_bwd_impl(
549553
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
550554
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
551-
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
552-
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
553-
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
554-
bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
555-
void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
556-
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
557-
void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
558-
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
559-
void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
560-
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
555+
int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability,
556+
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
557+
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
558+
bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
559+
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
560+
void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO,
561+
void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
562+
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
563+
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
564+
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
561565
using namespace transformer_engine;
562566

563567
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -622,6 +626,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
622626
0,
623627
bias_b,
624628
bias_h,
629+
bias_sq,
630+
bias_skv,
625631
scaling_factor,
626632
true,
627633
dropout_probability,
@@ -811,19 +817,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
811817
sdpa_backward_options.set_alibi_mask(is_alibi);
812818

813819
if (is_bias) {
814-
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
815-
.set_name("bias")
816-
.set_dim({bias_b, bias_h, s_q, s_kv})
817-
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
818-
dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
819-
.set_name("dBias")
820-
.set_dim({bias_b, bias_h, s_q, s_kv})
821-
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
820+
bias = mha_graph->tensor(
821+
fe::graph::Tensor_attributes()
822+
.set_name("bias")
823+
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
824+
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
822825
sdpa_backward_options.set_bias(bias);
823-
// shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s]
824-
// are not supported for dbias calculation but they are
825-
// supported for forward bias calculation
826-
if ((bias_b == 1) && (bias_h == h)) {
826+
// bias shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s], [1, h, s, s] are supported for dbias calculation
827+
// bias shape [1, 1, 1, s] is not supported for dbias calculation as of cuDNN 9.18
828+
if (!((bias_b == 1) && (bias_h == 1) && (bias_sq == 1))) {
829+
dBias = mha_graph->tensor(
830+
fe::graph::Tensor_attributes()
831+
.set_name("dBias")
832+
.set_dim({bias_b, bias_h, bias_sq, bias_skv})
833+
.set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1}));
827834
sdpa_backward_options.set_dbias(dBias);
828835
}
829836
}
@@ -974,10 +981,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
974981

975982
if (is_bias) {
976983
variant_pack[bias] = devPtrBias;
977-
if ((bias_b == 1) && (bias_h == h)) {
984+
if (dBias != nullptr) {
978985
variant_pack[dBias] = devPtrdBias;
979-
} else {
980-
variant_pack[dBias] = nullptr;
981986
}
982987
}
983988

@@ -1083,10 +1088,14 @@ void fused_attn_arbitrary_seqlen_fwd(
10831088
void *devPtrBias = nullptr;
10841089
size_t bias_b = 0;
10851090
size_t bias_h = 0;
1091+
size_t bias_sq = 0;
1092+
size_t bias_skv = 0;
10861093
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
10871094
devPtrBias = input_Bias->data.dptr;
10881095
bias_b = input_Bias->data.shape[0];
10891096
bias_h = input_Bias->data.shape[1];
1097+
bias_sq = input_Bias->data.shape[2];
1098+
bias_skv = input_Bias->data.shape[3];
10901099
}
10911100
void *devPtrSoftmaxOffset = nullptr;
10921101
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -1152,7 +1161,7 @@ void fused_attn_arbitrary_seqlen_fwd(
11521161
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
11531162
Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
11541163
output_bias->data.dptr = nullptr;
1155-
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
1164+
output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv};
11561165
output_bias->data.dtype = QKV_type;
11571166
}
11581167

@@ -1197,10 +1206,10 @@ void fused_attn_arbitrary_seqlen_fwd(
11971206
fused_attn_arbitrary_seqlen_fwd_impl(
11981207
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
11991208
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
1200-
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
1201-
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
1202-
window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV,
1203-
devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
1209+
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv,
1210+
is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
1211+
softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK,
1212+
devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
12041213
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
12051214
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
12061215
&workspace_size, stream, handle);
@@ -1244,11 +1253,15 @@ void fused_attn_arbitrary_seqlen_bwd(
12441253
void *devPtrdBias = nullptr;
12451254
size_t bias_b = 0;
12461255
size_t bias_h = 0;
1256+
size_t bias_sq = 0;
1257+
size_t bias_skv = 0;
12471258
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
12481259
devPtrBias = input_Bias->data.dptr;
12491260
devPtrdBias = output_dBias->data.dptr;
12501261
bias_b = output_dBias->data.shape[0];
12511262
bias_h = output_dBias->data.shape[1];
1263+
bias_sq = output_dBias->data.shape[2];
1264+
bias_skv = output_dBias->data.shape[3];
12521265
}
12531266

12541267
size_t max_batch_size = 0;
@@ -1291,11 +1304,11 @@ void fused_attn_arbitrary_seqlen_bwd(
12911304

12921305
fused_attn_arbitrary_seqlen_bwd_impl(
12931306
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
1294-
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
1295-
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
1296-
bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
1297-
devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
1298-
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
1307+
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale,
1308+
p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
1309+
window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO,
1310+
devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
1311+
devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
12991312
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
13001313
workspace->data.dptr, &workspace_size, stream, handle);
13011314

0 commit comments

Comments
 (0)