diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 9fce78f6c..b92942a74 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -9,6 +9,7 @@
 from functools import partial
 from math import sqrt
 from typing import Tuple, Optional, Dict
+import os
 import random
 
 import jax
@@ -26,7 +27,7 @@
 from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
 
 from transformer_engine.jax import autocast
-from transformer_engine.jax.sharding import MeshResource
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
 from transformer_engine.jax.attention import (
     AttnBiasType,
     AttnMaskType,
@@ -1252,3 +1253,167 @@ def test_jax_new_rng():
     )
     runner = FusedAttnRunner(**kwargs)
     runner.test_forward()
+
+
+def _run_deterministic_bwd_case(
+    qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d, check_numerical=None
+):
+    """
+    Shared helper for deterministic backward tests.
+
+    Verifies that the CK fused attention backward pass in deterministic mode
+    produces bitwise-reproducible gradients. Optionally checks numerical
+    correctness against an unfused JAX reference (O(s^2)); this is skipped for
+    large seq_len to keep CI fast.
+
+    All seq_len values should be >= 256 so that nsplits = ceil(s/kN0) > 1
+    (kN0=128 for d<=128), ensuring the kernel actually exercises the
+    deterministic split-accumulator path rather than the trivial single-split.
+    """
+    if check_numerical is None:
+        check_numerical = seq_len <= 256
+    s = seq_len
+    dtype = jnp.bfloat16
+    scaling_factor = 1.0 / sqrt(d)
+
+    # Set deterministic mode before any TE calls so the flag is visible
+    # throughout backend selection and kernel dispatch.
+    _orig_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
+    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
+    try:
+        # Verify the CK backend is selected, otherwise test is meaningless.
+        # NOTE: this check runs inside try/finally because pytest.skip() and a
+        # failed assert both raise, and the env var must still be restored.
+        backend = FusedAttnHelper(
+            True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type,
+            0.0, h_q, h_kv, s, s, d, d, (-1, -1),
+        ).get_fused_attn_backend()
+        if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
+            pytest.skip("No fused attention backend available for this config")
+        assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
+            f"Expected CK backend but got {backend}."
+        )
+
+        key = jax.random.PRNGKey(42)
+        q_key, k_key, v_key = jax.random.split(key, 3)
+        q = jax.random.normal(q_key, (b, s, h_q, d), dtype=dtype)
+        k = jax.random.normal(k_key, (b, s, h_kv, d), dtype=dtype)
+        v = jax.random.normal(v_key, (b, s, h_kv, d), dtype=dtype)
+
+        # Build sequence descriptor via the non-deprecated SequenceDescriptor API.
+        # For NO_MASK every sequence is full-length; for CAUSAL the mask type
+        # alone drives masking inside fused_attn.
+        seqlens = jnp.full((b,), s, dtype=jnp.int32)
+        seq_desc = SequenceDescriptor.from_seqlens((seqlens, seqlens))
+
+        # The unfused JAX reference (jax_dpa) still takes a raw mask ndarray.
+        if attn_mask_type == AttnMaskType.NO_MASK:
+            mask = None
+        else:
+            idx = jnp.arange(s)
+            mask = jnp.broadcast_to(
+                (idx[None, :] > idx[:, None])[None, None], (b, 1, s, s)
+            )
+
+        kwargs = dict(
+            attn_bias_type=AttnBiasType.NO_BIAS,
+            attn_mask_type=attn_mask_type,
+            scaling_factor=scaling_factor,
+            dropout_probability=0.0,
+            is_training=True,
+            qkv_layout=qkv_layout,
+        )
+
+        # Fused CK backward — JIT-compiled, run twice for bitwise reproducibility
+        def fused_fn(q, k, v):
+            return customcall_fused_dpa(
+                q, k, v, None, seq_desc, None, **kwargs
+            ).astype(jnp.float32).sum()
+
+        fused_val_grad = jit(jax.value_and_grad(fused_fn, argnums=(0, 1, 2)))
+
+        with global_shard_guard(MeshResource()):
+            _, grads1 = fused_val_grad(q, k, v)
+            fused_dq1 = np.array(grads1[0].block_until_ready())
+            fused_dk1 = np.array(grads1[1].block_until_ready())
+            fused_dv1 = np.array(grads1[2].block_until_ready())
+
+            _, grads2 = fused_val_grad(q, k, v)
+            fused_dq2 = np.array(grads2[0].block_until_ready())
+            fused_dk2 = np.array(grads2[1].block_until_ready())
+            fused_dv2 = np.array(grads2[2].block_until_ready())
+
+        # Bitwise reproducibility across consecutive runs
+        np.testing.assert_array_equal(fused_dq1, fused_dq2, err_msg="dQ not bitwise reproducible")
+        np.testing.assert_array_equal(fused_dk1, fused_dk2, err_msg="dK not bitwise reproducible")
+        np.testing.assert_array_equal(fused_dv1, fused_dv2, err_msg="dV not bitwise reproducible")
+
+        # Numerical correctness vs unfused JAX reference (O(s^2), skip for large s)
+        if check_numerical:
+            def ref_fn(q, k, v):
+                return jax_dpa(q, k, v, None, mask, None, **kwargs).astype(jnp.float32).sum()
+
+            ref_val_grad = jit(jax.value_and_grad(ref_fn, argnums=(0, 1, 2)))
+            _, ref_grads = ref_val_grad(q, k, v)
+            ref_dq = ref_grads[0].block_until_ready()
+            ref_dk = ref_grads[1].block_until_ready()
+            ref_dv = ref_grads[2].block_until_ready()
+
+            assert_allclose(jnp.array(fused_dq1), ref_dq, dtype=dtype)
+            assert_allclose(jnp.array(fused_dk1), ref_dk, dtype=dtype)
+            assert_allclose(jnp.array(fused_dv1), ref_dv, dtype=dtype)
+    finally:
+        if _orig_nondeterministic is None:
+            os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
+        else:
+            os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic
+
+
+@pytest.mark.skipif(
+    not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware"
+)
+@pytest.mark.parametrize(
+    "qkv_layout",
+    [
+        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
+        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
+        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
+    ],
+)
+@pytest.mark.parametrize(
+    "attn_mask_type",
+    [
+        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
+    ],
+)
+@pytest.mark.parametrize(
+    "b, seq_len, h_q, h_kv, d",
+    [
+        pytest.param(2, 256, 8, 8, 128, id="b2_s256_MHA"),
+        pytest.param(2, 2048, 8, 8, 128, id="b2_s2048_MHA"),
+    ],
+)
+def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d):
+    """Test CK deterministic backward: bitwise reproducibility + correctness."""
+    _run_deterministic_bwd_case(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d)
+
+
+@pytest.mark.skipif(
+    not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware"
+)
+@pytest.mark.parametrize(
+    "attn_mask_type",
+    [
+        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
+    ],
+)
+def test_deterministic_bwd_gqa(attn_mask_type):
+    """GQA variant: BSHD_BSHD_BSHD with h_q != h_kv."""
+    _run_deterministic_bwd_case(
+        qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
+        attn_mask_type=attn_mask_type,
+        b=2, seq_len=2048, h_q=12, h_kv=4, d=128,
+        check_numerical=False,
+    )
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
index e787b31c8..5f5f75793 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -488,7 +488,7 @@ void nvte_fused_attn_bwd_qkvpacked(
       attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
       window_size_left, window_size_right,
-      false, // TODO: enable deterministic after CK team show us how
+      deterministic,
       input_QKV, input_O, input_dO, input_Bias,
       output_S,
       output_dQKV, output_dBias,
       input_cu_seqlens, input_cu_seqlens_padded,
@@ -677,7 +677,7 @@ void nvte_fused_attn_bwd_kvpacked(
       attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
       window_size_left, window_size_right,
-      false, // TODO: enable deterministic after CK team show us how
+      deterministic,
       input_Q, input_KV,
       input_O, input_dO, input_Bias,
       output_S,
       output_dQ, output_dKV, output_dBias,
@@ -863,7 +863,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
       attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
       window_size_left, window_size_right,
-      false, // TODO: enable deterministic after CK team show us how
+      deterministic,
       input_Q, input_K, input_V,
       input_O, input_dO, input_Bias,
       output_S,
       output_dQ, output_dK, output_dV, output_dBias,