@@ -937,9 +937,9 @@ def cp_p2p_fwd_flash_attn(
     elif section == "upper-triangle":
         max_seqlen_q_ = max_seqlen_q // 2
     if section in ["lower-triangle", "upper-triangle"]:
-        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
             fa_forward_kwargs["window_size"] = (-1, -1)
-        elif fa_utils.v2_7_0_plus:
+        elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_forward_kwargs["window_size_left"] = -1
             fa_forward_kwargs["window_size_right"] = -1

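The change repeated throughout this diff: FlashAttention 3 is regrouped with the FlashAttention >= 2.7.0 branch, which passes the sliding window as separate `window_size_left` / `window_size_right` keyword arguments, while FlashAttention 2.3-2.6 keeps the single `window_size` tuple. In FlashAttention's convention, `(-1, -1)` means an unbounded window and `(-1, 0)` forbids attending to positions right of the diagonal (causal-style). A minimal sketch of the dispatch, using a hypothetical helper name that is not part of Transformer Engine:

```python
# Minimal sketch of the window-kwarg dispatch this diff applies in several places.
# `set_window_kwargs` is a hypothetical helper; it only mirrors the repeated if/elif.
def set_window_kwargs(kwargs, window, use_flash_attn_3, v2_3_plus, v2_7_0_plus):
    left, right = window  # e.g. (-1, -1): no window; (-1, 0): nothing right of the diagonal
    if v2_3_plus and not v2_7_0_plus:
        kwargs["window_size"] = (left, right)   # FA 2.3-2.6: single tuple kwarg
    elif use_flash_attn_3 or v2_7_0_plus:
        kwargs["window_size_left"] = left       # FA3 and FA >= 2.7.0:
        kwargs["window_size_right"] = right     # separate left/right kwargs
    return kwargs

fa_forward_kwargs = set_window_kwargs({}, (-1, -1), True, True, True)
```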
@@ -1189,9 +1189,9 @@ def cp_p2p_bwd_flash_attn(
 ):
     """Per-tile backward call of CP P2P with FlashAttention backend"""
     dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
-    if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+    if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
         fa_backward_kwargs["window_size"] = (-1, -1)
-    elif fa_utils.v2_7_0_plus:
+    elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
         fa_backward_kwargs["window_size_left"] = -1
         fa_backward_kwargs["window_size_right"] = -1
     if not use_flash_attn_3:
@@ -1201,9 +1201,9 @@ def cp_p2p_bwd_flash_attn(
     softmax_lse__ = softmax_lse
     causal_ = False
     if section == "diagonal":
-        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
             fa_backward_kwargs["window_size"] = (-1, 0)
-        elif fa_utils.v2_7_0_plus:
+        elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_backward_kwargs["window_size_left"] = -1
             fa_backward_kwargs["window_size_right"] = 0
         causal_ = True
@@ -1225,6 +1225,10 @@ def cp_p2p_bwd_flash_attn(
         dk=dk,
         dv=dv,
     )
+    if use_flash_attn_3:
+        fa_backward_kwargs["is_causal"] = causal_
+    else:
+        fa_backward_kwargs["causal"] = causal_
     flash_attn_bwd(
         dout_part,
         q_part,
@@ -1233,7 +1237,6 @@ def cp_p2p_bwd_flash_attn(
         out_part,
         softmax_lse__,
         *fa_backward_args_thd,
-        causal=causal_,
         **fa_backward_kwargs,
     )

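The second recurring change: the causal flag moves from an explicit `causal=...` keyword at the `flash_attn_bwd` call site into `fa_backward_kwargs`, under whichever name the selected backend expects (`is_causal` on the FlashAttention 3 path, `causal` on FlashAttention 2). Dropping the explicit keyword in the same hunk is required: once the dict carries the flag, passing it again at the call site would raise `TypeError: ... got multiple values for keyword argument 'causal'`. A minimal sketch with a hypothetical helper name:

```python
# Minimal sketch (hypothetical helper): pick the kwarg name the backend expects,
# then let the call site splat everything through **fa_backward_kwargs.
def set_causal_kwarg(kwargs, causal, use_flash_attn_3):
    if use_flash_attn_3:
        kwargs["is_causal"] = causal   # FA3 naming, per this diff
    else:
        kwargs["causal"] = causal      # FA2 naming
    return kwargs

fa_backward_kwargs = set_causal_kwarg({}, causal=True, use_flash_attn_3=False)
# flash_attn_bwd(..., **fa_backward_kwargs) -- no separate causal=... keyword here,
# or Python would see the same keyword twice.
```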
@@ -1508,7 +1511,8 @@ def forward(
             flash_attn_fwd = (
                 _flash_attn_fwd_v3  # pylint: disable=possibly-used-before-assignment
             )
-            fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
+            fa_forward_kwargs["window_size_left"] = -1
+            fa_forward_kwargs["window_size_right"] = 0 if causal else -1
         else:
             if qkv_format == "thd":
                 from transformer_engine.pytorch.attention.dot_product_attention.backends import (
@@ -2985,9 +2989,9 @@ def forward(
                     max_seqlen_q=max_seqlen_q,
                     max_seqlen_kv=max_seqlen_kv_,
                 )
-                if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                     fa_forward_kwargs["window_size"] = window_size_per_step[i]
-                elif fa_utils.v2_7_0_plus:
+                elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
                     fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
                 fa_outputs = flash_attn_fwd(
@@ -3206,13 +3210,15 @@ def backward(ctx, dout, *_args):
                 )
                 if not ctx.use_flash_attn_3:
                     fa_backward_kwargs["rng_state"] = rng_states[i]
-                if ctx.use_flash_attn_3 or (
-                    fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
-                ):
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                     fa_backward_kwargs["window_size"] = window_size_per_step[i]
-                elif fa_utils.v2_7_0_plus:
+                elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                     fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
+                if ctx.use_flash_attn_3:
+                    fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type
+                else:
+                    fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type
                 flash_attn_bwd(
                     dout_,
                     q_,
@@ -3221,7 +3227,6 @@ def backward(ctx, dout, *_args):
                     out_,
                     softmax_lse_per_step[i],
                     *fa_backward_args_thd,
-                    causal="causal" in ctx.attn_mask_type,
                     **fa_backward_kwargs,
                 )

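In this multi-step backward, everything that differs between context-parallel steps is kept in per-step containers and indexed by the loop counter: `window_size_per_step[i]`, `rng_states[i]`, `softmax_lse_per_step[i]`. A rough, self-contained sketch of that shape with placeholder data (names and loop bounds are assumptions, not the library's API):

```python
# Rough sketch of the per-step kwarg assembly in the backward loop.
# Placeholder data only; real tensors, streams, and the actual backward call are omitted.
use_flash_attn_3 = False
window_size_per_step = [(-1, 0), (-1, -1)]  # one (left, right) pair per CP step
rng_states = [object(), object()]           # stand-ins for saved RNG states
fa_backward_kwargs = {"deterministic": False}

for i, (left, right) in enumerate(window_size_per_step):
    step_kwargs = dict(fa_backward_kwargs)
    step_kwargs["window_size_left"] = left
    step_kwargs["window_size_right"] = right
    if not use_flash_attn_3:
        step_kwargs["rng_state"] = rng_states[i]  # FA2-only, per this diff
    # flash_attn_bwd(dout_, q_, k_, v_, out_, softmax_lse_per_step[i],
    #                *fa_backward_args_thd, **step_kwargs)
```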
@@ -3361,7 +3366,8 @@ def forward(
             )

             flash_attn_fwd = _flash_attn_fwd_v3
-            fa_forward_kwargs["window_size"] = window_size
+            fa_forward_kwargs["window_size_left"] = window_size[0]
+            fa_forward_kwargs["window_size_right"] = window_size[1]
         else:
             if qkv_format == "thd":
                 from transformer_engine.pytorch.attention.dot_product_attention.backends import (
@@ -3738,7 +3744,8 @@ def backward(ctx, dout, *_args):
             flash_attn_bwd = (
                 _flash_attn_bwd_v3  # pylint: disable=possibly-used-before-assignment
             )
-            fa_backward_kwargs["window_size"] = ctx.window_size
+            fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
+            fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
             fa_backward_kwargs["deterministic"] = ctx.deterministic
         else:
             if qkv_format == "thd":
@@ -3821,6 +3828,10 @@ def backward(ctx, dout, *_args):
         )
         if not ctx.use_flash_attn_3:
             fa_backward_kwargs["rng_state"] = rng_state
+            fa_backward_kwargs["causal"] = causal
+        else:
+            fa_backward_kwargs["is_causal"] = causal
+
         flash_attn_bwd(
             dout,
             q,
@@ -3829,7 +3840,6 @@ def backward(ctx, dout, *_args):
             out,
             softmax_lse,
             *fa_backward_args_thd,
-            causal=causal,
             **fa_backward_kwargs,
         )

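The last two hunks apply the same split to the single-call backward path: the FlashAttention 2 branch keeps both `rng_state` and `causal` inside `fa_backward_kwargs`, while the FlashAttention 3 branch sets only `is_causal` and never an `rng_state`, suggesting the FA3 backward entry point used here does not accept one. A compact sketch of that guard (hypothetical helper name):

```python
# Compact sketch of the final kwarg split per this diff:
# the FA2 path carries rng_state + causal, the FA3 path only is_causal.
def finalize_backward_kwargs(kwargs, causal, rng_state, use_flash_attn_3):
    if not use_flash_attn_3:
        kwargs["rng_state"] = rng_state
        kwargs["causal"] = causal
    else:
        kwargs["is_causal"] = causal
    return kwargs
```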