Skip to content

Commit f9752ca

Browse files
committed
Fix Flash Attention 3 backward API parameter naming
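Use is_causal instead of causal as the keyword for the causal flag when calling flash_attn_bwd with Flash Attention 3; the non-Flash-Attention-3 path still passes causal. The keyword is now set through fa_backward_kwargs at all three flash_attn_bwd call sites touched by this commit, and the two explicit causal= arguments are removed. This keeps the backward-pass calls consistent with the flash-attn library interface in use. Note that the new branches are gated on use_flash_attn_3, not on the flash-attn v2.7.0+ version check.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>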
1 parent b200742 commit f9752ca

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,10 @@ def cp_p2p_bwd_flash_attn(
12251225
dk=dk,
12261226
dv=dv,
12271227
)
1228+
if use_flash_attn_3:
1229+
fa_backward_kwargs["is_causal"] = causal_
1230+
else:
1231+
fa_backward_kwargs["causal"] = causal_
12281232
flash_attn_bwd(
12291233
dout_part,
12301234
q_part,
@@ -3213,6 +3217,10 @@ def backward(ctx, dout, *_args):
32133217
elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
32143218
fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
32153219
fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
3220+
if ctx.use_flash_attn_3:
3221+
fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type
3222+
else:
3223+
fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type
32163224
flash_attn_bwd(
32173225
dout_,
32183226
q_,
@@ -3221,7 +3229,6 @@ def backward(ctx, dout, *_args):
32213229
out_,
32223230
softmax_lse_per_step[i],
32233231
*fa_backward_args_thd,
3224-
causal="causal" in ctx.attn_mask_type,
32253232
**fa_backward_kwargs,
32263233
)
32273234

@@ -3821,8 +3828,11 @@ def backward(ctx, dout, *_args):
38213828
dk=dk,
38223829
dv=dv,
38233830
)
3824-
if not ctx.use_flash_attn_3:
3831+
if ctx.use_flash_attn_3:
3832+
fa_backward_kwargs["is_causal"] = causal
3833+
else:
38253834
fa_backward_kwargs["rng_state"] = rng_state
3835+
fa_backward_kwargs["causal"] = causal
38263836
flash_attn_bwd(
38273837
dout,
38283838
q,
@@ -3831,7 +3841,6 @@ def backward(ctx, dout, *_args):
38313841
out,
38323842
softmax_lse,
38333843
*fa_backward_args_thd,
3834-
causal=causal,
38353844
**fa_backward_kwargs,
38363845
)
38373846

0 commit comments

Comments
 (0)