From 00b1aa97773884a15d2c7cb79397e65477856f24 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Fri, 27 Mar 2026 13:34:02 -0500 Subject: [PATCH 1/3] [Fix] Added functionality and test for determinism in Fused Attention. --- tests/jax/test_fused_attn.py | 141 +++++++++++++++++- .../common/fused_attn_rocm/fused_attn.cpp | 6 +- 2 files changed, 143 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 9fce78f6c..2ebc6206f 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -9,6 +9,7 @@ from functools import partial from math import sqrt from typing import Tuple, Optional, Dict +import os import random import jax @@ -26,7 +27,7 @@ from transformer_engine.jax.cpp_extensions.misc import is_hip_extension from transformer_engine.jax import autocast -from transformer_engine.jax.sharding import MeshResource +from transformer_engine.jax.sharding import MeshResource, global_shard_guard from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, @@ -1252,3 +1253,141 @@ def test_jax_new_rng(): ) runner = FusedAttnRunner(**kwargs) runner.test_forward() + + + +@pytest.mark.skipif( + not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware" +) +@pytest.mark.parametrize( + "qkv_layout", + [ + pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"), + pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + ], +) +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + ], +) +@pytest.mark.parametrize( + "b, seq_len, h_q, h_kv, d", + [ + pytest.param(2, 256, 8, 8, 128, id="b2_s256_MHA"), + pytest.param(2, 2048, 8, 8, 128, id="b2_s2048_MHA"), + ], +) +def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d): + """ + Test that the CK fused attention backward pass in deterministic mode + produces bitwise-reproducible and numerically correct gradients. + + All seq_len values are >= 256 so that nsplits = ceil(s/kN0) > 1 + (kN0=128 for d<=128), ensuring the kernel actually exercises the + deterministic split-accumulator path rather than the trivial single-split. + """ + s = seq_len + dtype = jnp.bfloat16 + scaling_factor = 1.0 / sqrt(d) + + # Set deterministic mode before any TE calls so the flag is visible + # throughout backend selection and kernel dispatch. + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + + # Verify the CK backend is selected, otherwise test is meaningless + backend = FusedAttnHelper( + True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type, + 0.0, h_q, h_kv, s, s, d, d, (-1, -1), + ).get_fused_attn_backend() + if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: + pytest.skip("No fused attention backend available for this config") + assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, ( + f"Expected CK backend but got {backend}." + ) + try: + key = jax.random.PRNGKey(42) + q_key, k_key, v_key = jax.random.split(key, 3) + q = jax.random.normal(q_key, (b, s, h_q, d), dtype=dtype) + k = jax.random.normal(k_key, (b, s, h_kv, d), dtype=dtype) + v = jax.random.normal(v_key, (b, s, h_kv, d), dtype=dtype) + + if attn_mask_type == AttnMaskType.NO_MASK: + seq_desc = None + mask = None + else: + idx = jnp.arange(s) + causal_mask = idx[None, :] > idx[:, None] + seq_desc = jnp.broadcast_to(causal_mask[None, None], (b, 1, s, s)) + mask = seq_desc + + kwargs = dict( + attn_bias_type=AttnBiasType.NO_BIAS, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=0.0, + is_training=True, + qkv_layout=qkv_layout, + ) + + # Fused CK backward — run twice to check bitwise reproducibility + def fused_fn(q, k, v): + return customcall_fused_dpa( + q, k, v, None, seq_desc, None, **kwargs + ).astype(jnp.float32).sum() + + with global_shard_guard(MeshResource()): + _, grads1 = jax.value_and_grad(fused_fn, argnums=(0, 1, 2))(q, k, v) + fused_dq1 = np.array(grads1[0].block_until_ready()) + fused_dk1 = np.array(grads1[1].block_until_ready()) + fused_dv1 = np.array(grads1[2].block_until_ready()) + + _, grads2 = jax.value_and_grad(fused_fn, argnums=(0, 1, 2))(q, k, v) + fused_dq2 = np.array(grads2[0].block_until_ready()) + fused_dk2 = np.array(grads2[1].block_until_ready()) + fused_dv2 = np.array(grads2[2].block_until_ready()) + + # Bitwise reproducibility across consecutive runs + np.testing.assert_array_equal(fused_dq1, fused_dq2, err_msg="dQ not bitwise reproducible") + np.testing.assert_array_equal(fused_dk1, fused_dk2, err_msg="dK not bitwise reproducible") + np.testing.assert_array_equal(fused_dv1, fused_dv2, err_msg="dV not bitwise reproducible") + + # Numerical correctness vs unfused JAX reference + def ref_fn(q, k, v): + return jax_dpa(q, k, v, None, mask, None, **kwargs).astype(jnp.float32).sum() + + _, ref_grads = jax.value_and_grad(ref_fn, argnums=(0, 1, 2))(q, k, v) + ref_dq = ref_grads[0].block_until_ready() + ref_dk = ref_grads[1].block_until_ready() + ref_dv = ref_grads[2].block_until_ready() + + assert_allclose(jnp.array(fused_dq1), ref_dq, dtype=dtype) + assert_allclose(jnp.array(fused_dk1), ref_dk, dtype=dtype) + assert_allclose(jnp.array(fused_dv1), ref_dv, dtype=dtype) + finally: + os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) + + +@pytest.mark.skipif( + not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware" +) +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + ], +) +def test_deterministic_bwd_gqa(attn_mask_type): + """ + GQA variant of test_deterministic_bwd. + Only BSHD_BSHD_BSHD layout supports h_q != h_kv. + """ + test_deterministic_bwd( + qkv_layout=QKVLayout.BSHD_BSHD_BSHD, + attn_mask_type=attn_mask_type, + b=2, seq_len=2048, h_q=12, h_kv=4, d=128, + ) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index e787b31c8..5f5f75793 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -488,7 +488,7 @@ void nvte_fused_attn_bwd_qkvpacked( attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - false, // TODO: enable deterministic after CK team show us how + deterministic, input_QKV, input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, @@ -677,7 +677,7 @@ void nvte_fused_attn_bwd_kvpacked( attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - false, // TODO: enable deterministic after CK team show us how + deterministic, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, @@ -863,7 +863,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - false, // TODO: enable deterministic after CK team show us how + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, output_dBias, From 2eec8ff16bf07ffc7be02edcb8fd6c8749ee16bf Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Fri, 27 Mar 2026 13:50:48 -0500 Subject: [PATCH 2/3] [Fix] Removed extra space and restore env var --- tests/jax/test_fused_attn.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 2ebc6206f..f7dff5518 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1255,7 +1255,6 @@ def test_jax_new_rng(): runner.test_forward() - @pytest.mark.skipif( not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware" ) @@ -1296,6 +1295,7 @@ def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d) # Set deterministic mode before any TE calls so the flag is visible # throughout backend selection and kernel dispatch. + _orig_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO") os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" # Verify the CK backend is selected, otherwise test is meaningless @@ -1368,7 +1368,10 @@ def ref_fn(q, k, v): assert_allclose(jnp.array(fused_dk1), ref_dk, dtype=dtype) assert_allclose(jnp.array(fused_dv1), ref_dv, dtype=dtype) finally: - os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) + if _orig_nondeterministic is None: + os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) + else: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic @pytest.mark.skipif( From 6ce1b9d2ae5886fe18f068ce2bd0144025506f38 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Fri, 27 Mar 2026 14:11:31 -0500 Subject: [PATCH 3/3] [Fix] Addressed review comments, refactored. --- tests/jax/test_fused_attn.py | 123 ++++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 51 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f7dff5518..b92942a74 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1255,40 +1255,23 @@ def test_jax_new_rng(): runner.test_forward() -@pytest.mark.skipif( - not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware" -) -@pytest.mark.parametrize( - "qkv_layout", - [ - pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"), - pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), - ], -) -@pytest.mark.parametrize( - "attn_mask_type", - [ - pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), - pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), - ], -) -@pytest.mark.parametrize( - "b, seq_len, h_q, h_kv, d", - [ - pytest.param(2, 256, 8, 8, 128, id="b2_s256_MHA"), - pytest.param(2, 2048, 8, 8, 128, id="b2_s2048_MHA"), - ], -) -def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d): +def _run_deterministic_bwd_case( + qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d, check_numerical=None +): """ - Test that the CK fused attention backward pass in deterministic mode - produces bitwise-reproducible and numerically correct gradients. + Shared helper for deterministic backward tests. + + Verifies that the CK fused attention backward pass in deterministic mode + produces bitwise-reproducible gradients. Optionally checks numerical + correctness against an unfused JAX reference (O(s^2)); this is skipped for + large seq_len to keep CI fast. - All seq_len values are >= 256 so that nsplits = ceil(s/kN0) > 1 + All seq_len values should be >= 256 so that nsplits = ceil(s/kN0) > 1 (kN0=128 for d<=128), ensuring the kernel actually exercises the deterministic split-accumulator path rather than the trivial single-split. """ + if check_numerical is None: + check_numerical = seq_len <= 256 s = seq_len dtype = jnp.bfloat16 scaling_factor = 1.0 / sqrt(d) @@ -1315,14 +1298,20 @@ def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d) k = jax.random.normal(k_key, (b, s, h_kv, d), dtype=dtype) v = jax.random.normal(v_key, (b, s, h_kv, d), dtype=dtype) + # Build sequence descriptor via the non-deprecated SequenceDescriptor API. + # For NO_MASK every sequence is full-length; for CAUSAL the mask type + # alone drives masking inside fused_attn. + seqlens = jnp.full((b,), s, dtype=jnp.int32) + seq_desc = SequenceDescriptor.from_seqlens((seqlens, seqlens)) + + # The unfused JAX reference (jax_dpa) still takes a raw mask ndarray. if attn_mask_type == AttnMaskType.NO_MASK: - seq_desc = None mask = None else: idx = jnp.arange(s) - causal_mask = idx[None, :] > idx[:, None] - seq_desc = jnp.broadcast_to(causal_mask[None, None], (b, 1, s, s)) - mask = seq_desc + mask = jnp.broadcast_to( + (idx[None, :] > idx[:, None])[None, None], (b, 1, s, s) + ) kwargs = dict( attn_bias_type=AttnBiasType.NO_BIAS, @@ -1333,19 +1322,21 @@ def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d) qkv_layout=qkv_layout, ) - # Fused CK backward — run twice to check bitwise reproducibility + # Fused CK backward — JIT-compiled, run twice for bitwise reproducibility def fused_fn(q, k, v): return customcall_fused_dpa( q, k, v, None, seq_desc, None, **kwargs ).astype(jnp.float32).sum() + fused_val_grad = jit(jax.value_and_grad(fused_fn, argnums=(0, 1, 2))) + with global_shard_guard(MeshResource()): - _, grads1 = jax.value_and_grad(fused_fn, argnums=(0, 1, 2))(q, k, v) + _, grads1 = fused_val_grad(q, k, v) fused_dq1 = np.array(grads1[0].block_until_ready()) fused_dk1 = np.array(grads1[1].block_until_ready()) fused_dv1 = np.array(grads1[2].block_until_ready()) - _, grads2 = jax.value_and_grad(fused_fn, argnums=(0, 1, 2))(q, k, v) + _, grads2 = fused_val_grad(q, k, v) fused_dq2 = np.array(grads2[0].block_until_ready()) fused_dk2 = np.array(grads2[1].block_until_ready()) fused_dv2 = np.array(grads2[2].block_until_ready()) @@ -1355,18 +1346,20 @@ def fused_fn(q, k, v): np.testing.assert_array_equal(fused_dk1, fused_dk2, err_msg="dK not bitwise reproducible") np.testing.assert_array_equal(fused_dv1, fused_dv2, err_msg="dV not bitwise reproducible") - # Numerical correctness vs unfused JAX reference - def ref_fn(q, k, v): - return jax_dpa(q, k, v, None, mask, None, **kwargs).astype(jnp.float32).sum() + # Numerical correctness vs unfused JAX reference (O(s^2), skip for large s) + if check_numerical: + def ref_fn(q, k, v): + return jax_dpa(q, k, v, None, mask, None, **kwargs).astype(jnp.float32).sum() - _, ref_grads = jax.value_and_grad(ref_fn, argnums=(0, 1, 2))(q, k, v) - ref_dq = ref_grads[0].block_until_ready() - ref_dk = ref_grads[1].block_until_ready() - ref_dv = ref_grads[2].block_until_ready() + ref_val_grad = jit(jax.value_and_grad(ref_fn, argnums=(0, 1, 2))) + _, ref_grads = ref_val_grad(q, k, v) + ref_dq = ref_grads[0].block_until_ready() + ref_dk = ref_grads[1].block_until_ready() + ref_dv = ref_grads[2].block_until_ready() - assert_allclose(jnp.array(fused_dq1), ref_dq, dtype=dtype) - assert_allclose(jnp.array(fused_dk1), ref_dk, dtype=dtype) - assert_allclose(jnp.array(fused_dv1), ref_dv, dtype=dtype) + assert_allclose(jnp.array(fused_dq1), ref_dq, dtype=dtype) + assert_allclose(jnp.array(fused_dk1), ref_dk, dtype=dtype) + assert_allclose(jnp.array(fused_dv1), ref_dv, dtype=dtype) finally: if _orig_nondeterministic is None: os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) @@ -1374,6 +1367,36 @@ def ref_fn(q, k, v): os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic +@pytest.mark.skipif( + not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware" +) +@pytest.mark.parametrize( + "qkv_layout", + [ + pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"), + pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + ], +) +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + ], +) +@pytest.mark.parametrize( + "b, seq_len, h_q, h_kv, d", + [ + pytest.param(2, 256, 8, 8, 128, id="b2_s256_MHA"), + pytest.param(2, 2048, 8, 8, 128, id="b2_s2048_MHA"), + ], +) +def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d): + """Test CK deterministic backward: bitwise reproducibility + correctness.""" + _run_deterministic_bwd_case(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d) + + @pytest.mark.skipif( not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware" ) @@ -1385,12 +1408,10 @@ def ref_fn(q, k, v): ], ) def test_deterministic_bwd_gqa(attn_mask_type): - """ - GQA variant of test_deterministic_bwd. - Only BSHD_BSHD_BSHD layout supports h_q != h_kv. - """ - test_deterministic_bwd( + """GQA variant: BSHD_BSHD_BSHD with h_q != h_kv.""" + _run_deterministic_bwd_case( qkv_layout=QKVLayout.BSHD_BSHD_BSHD, attn_mask_type=attn_mask_type, b=2, seq_len=2048, h_q=12, h_kv=4, d=128, + check_numerical=False, )