Skip to content

Commit a245229

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b200742 commit a245229

1 file changed

Lines changed: 5 additions & 7 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -937,7 +937,7 @@ def cp_p2p_fwd_flash_attn(
937937
elif section == "upper-triangle":
938938
max_seqlen_q_ = max_seqlen_q // 2
939939
if section in ["lower-triangle", "upper-triangle"]:
940-
if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
940+
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
941941
fa_forward_kwargs["window_size"] = (-1, -1)
942942
elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
943943
fa_forward_kwargs["window_size_left"] = -1
@@ -1189,7 +1189,7 @@ def cp_p2p_bwd_flash_attn(
11891189
):
11901190
"""Per-tile backward call of CP P2P with FlashAttention backend"""
11911191
dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
1192-
if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
1192+
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
11931193
fa_backward_kwargs["window_size"] = (-1, -1)
11941194
elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
11951195
fa_backward_kwargs["window_size_left"] = -1
@@ -1201,7 +1201,7 @@ def cp_p2p_bwd_flash_attn(
12011201
softmax_lse__ = softmax_lse
12021202
causal_ = False
12031203
if section == "diagonal":
1204-
if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
1204+
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
12051205
fa_backward_kwargs["window_size"] = (-1, 0)
12061206
elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
12071207
fa_backward_kwargs["window_size_left"] = -1
@@ -2985,7 +2985,7 @@ def forward(
29852985
max_seqlen_q=max_seqlen_q,
29862986
max_seqlen_kv=max_seqlen_kv_,
29872987
)
2988-
if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
2988+
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
29892989
fa_forward_kwargs["window_size"] = window_size_per_step[i]
29902990
elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
29912991
fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
@@ -3206,9 +3206,7 @@ def backward(ctx, dout, *_args):
32063206
)
32073207
if not ctx.use_flash_attn_3:
32083208
fa_backward_kwargs["rng_state"] = rng_states[i]
3209-
if (
3210-
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
3211-
):
3209+
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
32123210
fa_backward_kwargs["window_size"] = window_size_per_step[i]
32133211
elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
32143212
fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]

0 commit comments

Comments
 (0)