@@ -937,7 +937,7 @@ def cp_p2p_fwd_flash_attn(
937937 elif section == "upper-triangle" :
938938 max_seqlen_q_ = max_seqlen_q // 2
939939 if section in ["lower-triangle" , "upper-triangle" ]:
940- if ( fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus ) :
940+ if fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus :
941941 fa_forward_kwargs ["window_size" ] = (- 1 , - 1 )
942942 elif use_flash_attn_3 or fa_utils .v2_7_0_plus :
943943 fa_forward_kwargs ["window_size_left" ] = - 1
@@ -1189,7 +1189,7 @@ def cp_p2p_bwd_flash_attn(
11891189):
11901190 """Per-tile backward call of CP P2P with FlashAttention backend"""
11911191 dq , dk , dv = [torch .empty_like (x ) for x in [q_part , k_part , v_part ]]
1192- if ( fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus ) :
1192+ if fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus :
11931193 fa_backward_kwargs ["window_size" ] = (- 1 , - 1 )
11941194 elif use_flash_attn_3 or fa_utils .v2_7_0_plus :
11951195 fa_backward_kwargs ["window_size_left" ] = - 1
@@ -1201,7 +1201,7 @@ def cp_p2p_bwd_flash_attn(
12011201 softmax_lse__ = softmax_lse
12021202 causal_ = False
12031203 if section == "diagonal" :
1204- if ( fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus ) :
1204+ if fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus :
12051205 fa_backward_kwargs ["window_size" ] = (- 1 , 0 )
12061206 elif use_flash_attn_3 or fa_utils .v2_7_0_plus :
12071207 fa_backward_kwargs ["window_size_left" ] = - 1
@@ -2985,7 +2985,7 @@ def forward(
29852985 max_seqlen_q = max_seqlen_q ,
29862986 max_seqlen_kv = max_seqlen_kv_ ,
29872987 )
2988- if ( fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus ) :
2988+ if fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus :
29892989 fa_forward_kwargs ["window_size" ] = window_size_per_step [i ]
29902990 elif use_flash_attn_3 or fa_utils .v2_7_0_plus :
29912991 fa_forward_kwargs ["window_size_left" ] = window_size_per_step [i ][0 ]
@@ -3206,9 +3206,7 @@ def backward(ctx, dout, *_args):
32063206 )
32073207 if not ctx .use_flash_attn_3 :
32083208 fa_backward_kwargs ["rng_state" ] = rng_states [i ]
3209- if (
3210- fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus
3211- ):
3209+ if fa_utils .v2_3_plus and not fa_utils .v2_7_0_plus :
32123210 fa_backward_kwargs ["window_size" ] = window_size_per_step [i ]
32133211 elif ctx .use_flash_attn_3 or fa_utils .v2_7_0_plus :
32143212 fa_backward_kwargs ["window_size_left" ] = window_size_per_step [i ][0 ]
0 commit comments