diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index bd6b626b64..10ba99595b 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -937,9 +937,9 @@ def cp_p2p_fwd_flash_attn(
     elif section == "upper-triangle":
         max_seqlen_q_ = max_seqlen_q // 2
     if section in ["lower-triangle", "upper-triangle"]:
-        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
             fa_forward_kwargs["window_size"] = (-1, -1)
-        elif fa_utils.v2_7_0_plus:
+        elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_forward_kwargs["window_size_left"] = -1
             fa_forward_kwargs["window_size_right"] = -1
 
@@ -1189,9 +1189,9 @@ def cp_p2p_bwd_flash_attn(
 ):
     """Per-tile backward call of CP P2P with FlashAttention backend"""
     dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
-    if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+    if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
         fa_backward_kwargs["window_size"] = (-1, -1)
-    elif fa_utils.v2_7_0_plus:
+    elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
         fa_backward_kwargs["window_size_left"] = -1
         fa_backward_kwargs["window_size_right"] = -1
     if not use_flash_attn_3:
@@ -1201,9 +1201,9 @@ def cp_p2p_bwd_flash_attn(
         softmax_lse__ = softmax_lse
     causal_ = False
     if section == "diagonal":
-        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
             fa_backward_kwargs["window_size"] = (-1, 0)
-        elif fa_utils.v2_7_0_plus:
+        elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_backward_kwargs["window_size_left"] = -1
             fa_backward_kwargs["window_size_right"] = 0
         causal_ = True
@@ -1225,6 +1225,10 @@ def cp_p2p_bwd_flash_attn(
             dk=dk,
             dv=dv,
         )
+    if use_flash_attn_3:
+        fa_backward_kwargs["is_causal"] = causal_
+    else:
+        fa_backward_kwargs["causal"] = causal_
     flash_attn_bwd(
         dout_part,
         q_part,
@@ -1233,7 +1237,6 @@ def cp_p2p_bwd_flash_attn(
         out_part,
         softmax_lse__,
         *fa_backward_args_thd,
-        causal=causal_,
         **fa_backward_kwargs,
     )
 
@@ -1508,7 +1511,8 @@ def forward(
            flash_attn_fwd = (
                _flash_attn_fwd_v3  # pylint: disable=possibly-used-before-assignment
            )
-           fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
+           fa_forward_kwargs["window_size_left"] = -1
+           fa_forward_kwargs["window_size_right"] = 0 if causal else -1
        else:
            if qkv_format == "thd":
                from transformer_engine.pytorch.attention.dot_product_attention.backends import (
@@ -2985,9 +2989,9 @@ def forward(
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv_,
                )
-            if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+            if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                fa_forward_kwargs["window_size"] = window_size_per_step[i]
-            elif fa_utils.v2_7_0_plus:
+            elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
                fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
                fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
            fa_outputs = flash_attn_fwd(
@@ -3206,13 +3210,15 @@ def backward(ctx, dout, *_args):
                    )
                if not ctx.use_flash_attn_3:
                    fa_backward_kwargs["rng_state"] = rng_states[i]
-                if ctx.use_flash_attn_3 or (
-                    fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
-                ):
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size"] = window_size_per_step[i]
-                elif fa_utils.v2_7_0_plus:
+                elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                    fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
+                if ctx.use_flash_attn_3:
+                    fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type
+                else:
+                    fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type
                flash_attn_bwd(
                    dout_,
                    q_,
@@ -3221,7 +3227,6 @@ def backward(ctx, dout, *_args):
                    out_,
                    softmax_lse_per_step[i],
                    *fa_backward_args_thd,
-                    causal="causal" in ctx.attn_mask_type,
                    **fa_backward_kwargs,
                )
 
@@ -3361,7 +3366,8 @@ def forward(
                )
            flash_attn_fwd = _flash_attn_fwd_v3
-           fa_forward_kwargs["window_size"] = window_size
+           fa_forward_kwargs["window_size_left"] = window_size[0]
+           fa_forward_kwargs["window_size_right"] = window_size[1]
        else:
            if qkv_format == "thd":
                from transformer_engine.pytorch.attention.dot_product_attention.backends import (
@@ -3738,7 +3744,8 @@ def backward(ctx, dout, *_args):
            flash_attn_bwd = (
                _flash_attn_bwd_v3  # pylint: disable=possibly-used-before-assignment
            )
-           fa_backward_kwargs["window_size"] = ctx.window_size
+           fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
+           fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
            fa_backward_kwargs["deterministic"] = ctx.deterministic
        else:
            if qkv_format == "thd":
@@ -3821,6 +3828,10 @@ def backward(ctx, dout, *_args):
            )
        if not ctx.use_flash_attn_3:
            fa_backward_kwargs["rng_state"] = rng_state
+            fa_backward_kwargs["causal"] = causal
+        else:
+            fa_backward_kwargs["is_causal"] = causal
+
        flash_attn_bwd(
            dout,
            q,
@@ -3829,7 +3840,6 @@ def backward(ctx, dout, *_args):
            out,
            softmax_lse,
            *fa_backward_args_thd,
-            causal=causal,
            **fa_backward_kwargs,
        )
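
Reviewer note (not part of the patch): every hunk above applies the same version dispatch when populating the FlashAttention keyword arguments, and the backward hunks additionally move the causal flag from a positional `causal=` argument into the kwargs dict so the same call site works for both backends. The sketch below is only a summary of that rule under the flag and kwarg names visible in the diff (`use_flash_attn_3`, `fa_utils.v2_3_plus`, `fa_utils.v2_7_0_plus`, `window_size*`, `causal`/`is_causal`); the helper function itself is hypothetical and does not exist in the patch.

# Illustrative sketch only: the per-backend kwarg spelling the hunks above converge on.
def _set_fa_window_and_causal_kwargs(
    kwargs, window_size, causal, use_flash_attn_3, v2_3_plus, v2_7_0_plus
):
    """Populate sliding-window and causal kwargs for the selected FlashAttention backend."""
    if v2_3_plus and not v2_7_0_plus:
        # FlashAttention 2.3.x-2.6.x: a single (left, right) tuple.
        kwargs["window_size"] = window_size
    elif use_flash_attn_3 or v2_7_0_plus:
        # FlashAttention 3 and FlashAttention 2.7+: separate left/right arguments.
        kwargs["window_size_left"] = window_size[0]
        kwargs["window_size_right"] = window_size[1]
    # The causal flag is spelled "is_causal" for FlashAttention 3 and "causal" otherwise,
    # which is why the patch moves it into kwargs instead of passing it positionally.
    if use_flash_attn_3:
        kwargs["is_causal"] = causal
    else:
        kwargs["causal"] = causal
    return kwargs

# Example: a causal diagonal tile with an unbounded left window, as in the "diagonal" branch.
fa_kwargs = _set_fa_window_and_causal_kwargs(
    {}, (-1, 0), True, use_flash_attn_3=True, v2_3_plus=True, v2_7_0_plus=True
)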