From b200742d068d59c9bf5050190574626a84dcc60c Mon Sep 17 00:00:00 2001
From: meichaoyang001
Date: Wed, 25 Feb 2026 15:07:16 +0800
Subject: [PATCH 1/6] Fix Flash Attention 3 API compatibility for window size
 parameters

Replace the single window_size parameter with window_size_left and
window_size_right in flash_attn_fwd to align with the flash-attn v2.7.0+
API changes.

- Update the function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with the Flash Attention v2 implementation

Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001
---
 .../dot_product_attention/context_parallel.py | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index bd6b626b64..8799286323 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -937,9 +937,9 @@ def cp_p2p_fwd_flash_attn(
     elif section == "upper-triangle":
         max_seqlen_q_ = max_seqlen_q // 2
     if section in ["lower-triangle", "upper-triangle"]:
-        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
             fa_forward_kwargs["window_size"] = (-1, -1)
-        elif fa_utils.v2_7_0_plus:
+        elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_forward_kwargs["window_size_left"] = -1
             fa_forward_kwargs["window_size_right"] = -1
@@ -1189,9 +1189,9 @@ def cp_p2p_bwd_flash_attn(
 ):
     """Per-tile backward call of CP P2P with FlashAttention backend"""
     dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
-    if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+    if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
         fa_backward_kwargs["window_size"] = (-1, -1)
-    elif fa_utils.v2_7_0_plus:
+    elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
         fa_backward_kwargs["window_size_left"] = -1
         fa_backward_kwargs["window_size_right"] = -1
     if not use_flash_attn_3:
@@ -1201,9 +1201,9 @@ def cp_p2p_bwd_flash_attn(
     softmax_lse__ = softmax_lse
     causal_ = False
     if section == "diagonal":
-        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
             fa_backward_kwargs["window_size"] = (-1, 0)
-        elif fa_utils.v2_7_0_plus:
+        elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_backward_kwargs["window_size_left"] = -1
             fa_backward_kwargs["window_size_right"] = 0
         causal_ = True
@@ -1233,7 +1233,6 @@ def cp_p2p_bwd_flash_attn(
         out_part,
         softmax_lse__,
         *fa_backward_args_thd,
-        causal=causal_,
         **fa_backward_kwargs,
     )
@@ -1508,7 +1507,8 @@ def forward(
                 flash_attn_fwd = (
                     _flash_attn_fwd_v3  # pylint: disable=possibly-used-before-assignment
                 )
-                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
+                fa_forward_kwargs["window_size_left"] = -1
+                fa_forward_kwargs["window_size_right"] = 0 if causal else -1
             else:
                 if qkv_format == "thd":
                     from transformer_engine.pytorch.attention.dot_product_attention.backends import (
@@ -2985,9 +2985,9 @@ def forward(
                     max_seqlen_q=max_seqlen_q,
                     max_seqlen_kv=max_seqlen_kv_,
                 )
-                if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
                     fa_forward_kwargs["window_size"] = window_size_per_step[i]
-                elif fa_utils.v2_7_0_plus:
+                elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
                     fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1]
                 fa_outputs = flash_attn_fwd(
@@ -3206,11 +3206,11 @@ def backward(ctx, dout, *_args):
                     )
                 if not ctx.use_flash_attn_3:
                     fa_backward_kwargs["rng_state"] = rng_states[i]
-                if ctx.use_flash_attn_3 or (
+                if (
                     fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
                 ):
                     fa_backward_kwargs["window_size"] = window_size_per_step[i]
-                elif fa_utils.v2_7_0_plus:
+                elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                     fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
                 flash_attn_bwd(
@@ -3361,7 +3361,8 @@ def forward(
                 )
                 flash_attn_fwd = _flash_attn_fwd_v3
-                fa_forward_kwargs["window_size"] = window_size
+                fa_forward_kwargs["window_size_left"] = window_size[0]
+                fa_forward_kwargs["window_size_right"] = window_size[1]
             else:
                 if qkv_format == "thd":
                     from transformer_engine.pytorch.attention.dot_product_attention.backends import (
@@ -3738,7 +3739,8 @@ def backward(ctx, dout, *_args):
                 flash_attn_bwd = (
                     _flash_attn_bwd_v3  # pylint: disable=possibly-used-before-assignment
                 )
-                fa_backward_kwargs["window_size"] = ctx.window_size
+                fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
+                fa_backward_kwargs["window_size_right"] = ctx.window_size[1]
                 fa_backward_kwargs["deterministic"] = ctx.deterministic
             else:
                 if qkv_format == "thd":
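The same version gate is repeated at every call site this patch touches. For illustration only, the pattern could be captured in one helper; `window_size_kwargs` and its parameters below are hypothetical names, not part of the patch, and the flags mirror `fa_utils.v2_3_plus`, `fa_utils.v2_7_0_plus`, and `use_flash_attn_3` from context_parallel.py:

```python
# Hypothetical sketch, not part of the patch: map a (left, right) sliding
# window to the kwargs each flash-attn generation expects.
def window_size_kwargs(window_size, use_flash_attn_3, v2_3_plus, v2_7_0_plus):
    left, right = window_size
    if v2_3_plus and not v2_7_0_plus:
        # flash-attn v2.3 up to (but not including) v2.7.0: one tuple kwarg.
        return {"window_size": (left, right)}
    if use_flash_attn_3 or v2_7_0_plus:
        # flash-attn v2.7.0+ and FA3: two scalar kwargs.
        return {"window_size_left": left, "window_size_right": right}
    return {}  # older flash-attn: no sliding-window support

# e.g. fa_forward_kwargs.update(window_size_kwargs(
#     (-1, 0) if causal else (-1, -1),
#     use_flash_attn_3, fa_utils.v2_3_plus, fa_utils.v2_7_0_plus))
```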
From f9752cae074b261a600b5c68e95e366db56a80e5 Mon Sep 17 00:00:00 2001
From: meichaoyang001
Date: Wed, 25 Feb 2026 15:52:39 +0800
Subject: [PATCH 2/6] Fix Flash Attention 3 backward API parameter naming

Rename the causal parameter to is_causal in flash_attn_bwd calls on the
Flash Attention 3 path, keeping causal on the flash-attn v2 path. This
aligns the backward pass with the flash-attn v3 interface; flash-attn v2
(including v2.7.0+) still expects causal.

Signed-off-by: meichaoyang001
---
 .../dot_product_attention/context_parallel.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 8799286323..c2880f7836 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -1225,6 +1225,10 @@ def cp_p2p_bwd_flash_attn(
         dk=dk,
         dv=dv,
     )
+    if use_flash_attn_3:
+        fa_backward_kwargs["is_causal"] = causal_
+    else:
+        fa_backward_kwargs["causal"] = causal_
     flash_attn_bwd(
         dout_part,
         q_part,
@@ -3213,6 +3217,10 @@ def backward(ctx, dout, *_args):
                 elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                     fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
+                if ctx.use_flash_attn_3:
+                    fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type
+                else:
+                    fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type
                 flash_attn_bwd(
                     dout_,
                     q_,
@@ -3221,7 +3229,6 @@ def backward(ctx, dout, *_args):
                     out_,
                     softmax_lse_per_step[i],
                     *fa_backward_args_thd,
-                    causal="causal" in ctx.attn_mask_type,
                     **fa_backward_kwargs,
                 )
@@ -3821,8 +3828,11 @@ def backward(ctx, dout, *_args):
                 dk=dk,
                 dv=dv,
             )
-            if not ctx.use_flash_attn_3:
+            if ctx.use_flash_attn_3:
+                fa_backward_kwargs["is_causal"] = causal
+            else:
                 fa_backward_kwargs["rng_state"] = rng_state
+                fa_backward_kwargs["causal"] = causal
             flash_attn_bwd(
                 dout,
                 q,
@@ -3831,7 +3841,6 @@ def backward(ctx, dout, *_args):
                 out,
                 softmax_lse,
                 *fa_backward_args_thd,
-                causal=causal,
                 **fa_backward_kwargs,
             )
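The dispatch this patch introduces is the same boolean under two spellings: FA3 takes `is_causal`, flash-attn v2 takes `causal`. A minimal sketch, assuming only the naming difference shown in the hunks above (`causal_kwargs` is a hypothetical helper, not part of the patch):

```python
def causal_kwargs(causal: bool, use_flash_attn_3: bool) -> dict:
    # FA3 spells the flag "is_causal"; flash-attn v2 spells it "causal".
    return {"is_causal" if use_flash_attn_3 else "causal": causal}

# e.g. fa_backward_kwargs.update(
#     causal_kwargs("causal" in ctx.attn_mask_type, ctx.use_flash_attn_3))
```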
From 71ffb7956021e4a8d87dbeb3f239ebbaefb9496e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 25 Feb 2026 07:58:25 +0000
Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../dot_product_attention/context_parallel.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index c2880f7836..bd18298f21 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -937,7 +937,7 @@ def cp_p2p_fwd_flash_attn(
     elif section == "upper-triangle":
         max_seqlen_q_ = max_seqlen_q // 2
     if section in ["lower-triangle", "upper-triangle"]:
-        if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
             fa_forward_kwargs["window_size"] = (-1, -1)
         elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_forward_kwargs["window_size_left"] = -1
@@ -1189,7 +1189,7 @@ def cp_p2p_bwd_flash_attn(
 ):
     """Per-tile backward call of CP P2P with FlashAttention backend"""
     dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
-    if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+    if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
         fa_backward_kwargs["window_size"] = (-1, -1)
     elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
         fa_backward_kwargs["window_size_left"] = -1
@@ -1201,7 +1201,7 @@ def cp_p2p_bwd_flash_attn(
     softmax_lse__ = softmax_lse
     causal_ = False
     if section == "diagonal":
-        if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+        if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
             fa_backward_kwargs["window_size"] = (-1, 0)
         elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
             fa_backward_kwargs["window_size_left"] = -1
@@ -2989,7 +2989,7 @@ def forward(
                     max_seqlen_q=max_seqlen_q,
                     max_seqlen_kv=max_seqlen_kv_,
                 )
-                if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                     fa_forward_kwargs["window_size"] = window_size_per_step[i]
                 elif use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
@@ -3210,9 +3210,7 @@ def backward(ctx, dout, *_args):
                     )
                 if not ctx.use_flash_attn_3:
                     fa_backward_kwargs["rng_state"] = rng_states[i]
-                if (
-                    fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
-                ):
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                     fa_backward_kwargs["window_size"] = window_size_per_step[i]
                 elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
"diagonal": - if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, 0) elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -2989,7 +2989,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, ) - if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size"] = window_size_per_step[i] elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] @@ -3210,9 +3210,7 @@ def backward(ctx, dout, *_args): ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] - if ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = window_size_per_step[i] elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] From ce2c14b9eb9f001293d8963259fba17d69f7791f Mon Sep 17 00:00:00 2001 From: meichaoyang001 Date: Wed, 25 Feb 2026 16:43:44 +0800 Subject: [PATCH 5/6] Refactor Flash Attention 3 to use positional args instead of kwargs Replace keyword arguments with positional arguments in flash_attn_fwd and flash_attn_bwd to abstract away parameter naming differences (causal vs is_causal) between flash-attn versions. This provides a more robust interface that is resilient to future API changes in the flash-attn library. - Convert window_size_left, window_size_right, and causal parameters to positional args in both forward and backward functions - Eliminate version-specific parameter naming dependencies - Simplify compatibility handling across flash-attn v2.7.0+ variants Signed-off-by: meichaoyang001 --- .../dot_product_attention/context_parallel.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index bd18298f21..9946e9a0e8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1225,10 +1225,10 @@ def cp_p2p_bwd_flash_attn( dk=dk, dv=dv, ) - if use_flash_attn_3: - fa_backward_kwargs["is_causal"] = causal_ - else: - fa_backward_kwargs["causal"] = causal_ + # if use_flash_attn_3: + # fa_backward_kwargs["is_causal"] = causal_ + # else: + # fa_backward_kwargs["causal"] = causal_ flash_attn_bwd( dout_part, q_part, @@ -1237,6 +1237,7 @@ def cp_p2p_bwd_flash_attn( out_part, softmax_lse__, *fa_backward_args_thd, + causal_, **fa_backward_kwargs, ) @@ -2999,7 +3000,7 @@ def forward( k_, v_, *fa_forward_args_thd, - causal=causal, + causal, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: @@ -3215,10 +3216,10 @@ def backward(ctx, dout, *_args): elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] - if ctx.use_flash_attn_3: - fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type - else: - fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type + # if ctx.use_flash_attn_3: + # fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type + # else: + # fa_backward_kwargs["causal"] = "causal" in 
From de8483f78c70ae097d5db66bea9d9eb385524909 Mon Sep 17 00:00:00 2001
From: meichaoyang001
Date: Wed, 25 Feb 2026 17:40:50 +0800
Subject: [PATCH 6/6] Fix Flash Attention 3 backward API parameter naming

Rename the causal parameter to is_causal in flash_attn_bwd calls on the
Flash Attention 3 path, restoring the keyword-based dispatch and reverting
the positional-argument approach from PATCH 5/6. This keeps the backward
pass consistent with the flash-attn v3 library interface.

Signed-off-by: meichaoyang001
---
 .../dot_product_attention/context_parallel.py | 29 +++++++++----------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 9946e9a0e8..10ba99595b 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -1225,10 +1225,10 @@ def cp_p2p_bwd_flash_attn(
         dk=dk,
         dv=dv,
     )
-    # if use_flash_attn_3:
-    #     fa_backward_kwargs["is_causal"] = causal_
-    # else:
-    #     fa_backward_kwargs["causal"] = causal_
+    if use_flash_attn_3:
+        fa_backward_kwargs["is_causal"] = causal_
+    else:
+        fa_backward_kwargs["causal"] = causal_
     flash_attn_bwd(
         dout_part,
         q_part,
@@ -1237,7 +1236,6 @@ def cp_p2p_bwd_flash_attn(
         out_part,
         softmax_lse__,
         *fa_backward_args_thd,
-        causal_,
         **fa_backward_kwargs,
     )
@@ -3000,7 +2999,7 @@ def forward(
                     k_,
                     v_,
                     *fa_forward_args_thd,
-                    causal,
+                    causal=causal,
                     **fa_forward_kwargs,
                 )
                 if not fa_utils.v2_7_0_plus:
@@ -3216,10 +3215,10 @@ def backward(ctx, dout, *_args):
                 elif ctx.use_flash_attn_3 or fa_utils.v2_7_0_plus:
                     fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                     fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
-                # if ctx.use_flash_attn_3:
-                #     fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type
-                # else:
-                #     fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type
+                if ctx.use_flash_attn_3:
+                    fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type
+                else:
+                    fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type
                 flash_attn_bwd(
                     dout_,
                     q_,
@@ -3228,7 +3227,6 @@ def backward(ctx, dout, *_args):
                     out_,
                     softmax_lse_per_step[i],
                     *fa_backward_args_thd,
-                    ctx.attn_mask_type,
                     **fa_backward_kwargs,
                 )
@@ -3539,7 +3537,7 @@ def forward(
                     k_part,
                     v_part,
                     *fa_forward_args_thd,
-                    causal,
+                    causal=causal,
                     **fa_forward_kwargs,
                 )
                 if not fa_utils.v2_7_0_plus:
@@ -3830,9 +3828,9 @@ def backward(ctx, dout, *_args):
             )
             if not ctx.use_flash_attn_3:
                 fa_backward_kwargs["rng_state"] = rng_state
-            # fa_backward_kwargs["causal"] = causal
-            # else:
-            #     fa_backward_kwargs["is_causal"] = causal
+                fa_backward_kwargs["causal"] = causal
+            else:
+                fa_backward_kwargs["is_causal"] = causal
 
             flash_attn_bwd(
@@ -3842,7 +3840,6 @@ def backward(ctx, dout, *_args):
                 out,
                 softmax_lse,
                 *fa_backward_args_thd,
-                causal,
                 **fa_backward_kwargs,
             )
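Where the series lands: the FA3 backward takes is_causal and no rng_state, while the flash-attn v2 backward takes causal plus rng_state. Summarized as a sketch (`backward_extra_kwargs` is a hypothetical helper mirroring the final hunk, not code from the patch):

```python
def backward_extra_kwargs(causal, rng_state, use_flash_attn_3):
    # Mirrors the final state of the gated block in context_parallel.py.
    if use_flash_attn_3:
        return {"is_causal": causal}
    return {"causal": causal, "rng_state": rng_state}
```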