Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import os
import random

import jax
Expand All @@ -26,7 +27,7 @@

from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
from transformer_engine.jax import autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.sharding import MeshResource, global_shard_guard
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
Expand Down Expand Up @@ -1252,3 +1253,165 @@ def test_jax_new_rng():
)
runner = FusedAttnRunner(**kwargs)
runner.test_forward()


def _run_deterministic_bwd_case(
    qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d, check_numerical=None
):
    """
    Shared helper for deterministic backward tests.

    Verifies that the CK fused attention backward pass in deterministic mode
    produces bitwise-reproducible gradients. Optionally checks numerical
    correctness against an unfused JAX reference (O(s^2)); this is skipped for
    large seq_len to keep CI fast.

    All seq_len values should be >= 256 so that nsplits = ceil(s/kN0) > 1
    (kN0=128 for d<=128), ensuring the kernel actually exercises the
    deterministic split-accumulator path rather than the trivial single-split.

    Args:
        qkv_layout: QKVLayout under test (packed or separate tensors).
        attn_mask_type: AttnMaskType under test (NO_MASK or CAUSAL_MASK).
        b: batch size.
        seq_len: sequence length used for both Q and KV.
        h_q: number of query heads.
        h_kv: number of key/value heads (h_kv != h_q exercises GQA).
        d: head dimension.
        check_numerical: force the numerical check on/off; when None it is
            enabled only for seq_len <= 256 because the unfused reference
            materializes the full s x s attention matrix.

    Skips (rather than fails) when no fused backend, or a non-CK backend, is
    selected for the configuration.
    """
    if check_numerical is None:
        check_numerical = seq_len <= 256
    s = seq_len
    dtype = jnp.bfloat16
    scaling_factor = 1.0 / sqrt(d)

    # Set deterministic mode before any TE calls so the flag is visible
    # throughout backend selection and kernel dispatch. Cache the caller's
    # value so it can be restored in the finally block below.
    _orig_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

    try:
        # Backend probing happens INSIDE the try block: pytest.skip raises,
        # and skipping before the try would leak the env override into
        # subsequent tests.
        backend = FusedAttnHelper(
            True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type,
            0.0, h_q, h_kv, s, s, d, d, (-1, -1),
        ).get_fused_attn_backend()
        if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("No fused attention backend available for this config")
        # Only the CK backend implements the deterministic split-accumulator
        # path under test; another backend being chosen is a configuration
        # mismatch, not a deterministic-mode failure, so skip instead of fail.
        if backend != NVTE_Fused_Attn_Backend.NVTE_CK:
            pytest.skip(f"Deterministic CK test requires CK backend, got {backend}.")

        key = jax.random.PRNGKey(42)
        q_key, k_key, v_key = jax.random.split(key, 3)
        q = jax.random.normal(q_key, (b, s, h_q, d), dtype=dtype)
        k = jax.random.normal(k_key, (b, s, h_kv, d), dtype=dtype)
        v = jax.random.normal(v_key, (b, s, h_kv, d), dtype=dtype)

        # Build sequence descriptor via the non-deprecated SequenceDescriptor API.
        # For NO_MASK every sequence is full-length; for CAUSAL the mask type
        # alone drives masking inside fused_attn.
        seqlens = jnp.full((b,), s, dtype=jnp.int32)
        seq_desc = SequenceDescriptor.from_seqlens((seqlens, seqlens))

        # The unfused JAX reference (jax_dpa) still takes a raw mask ndarray.
        if attn_mask_type == AttnMaskType.NO_MASK:
            mask = None
        else:
            idx = jnp.arange(s)
            mask = jnp.broadcast_to(
                (idx[None, :] > idx[:, None])[None, None], (b, 1, s, s)
            )

        kwargs = dict(
            attn_bias_type=AttnBiasType.NO_BIAS,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=0.0,
            is_training=True,
            qkv_layout=qkv_layout,
        )

        # Fused CK backward — JIT-compiled, run twice for bitwise reproducibility
        def fused_fn(q, k, v):
            return customcall_fused_dpa(
                q, k, v, None, seq_desc, None, **kwargs
            ).astype(jnp.float32).sum()

        fused_val_grad = jit(jax.value_and_grad(fused_fn, argnums=(0, 1, 2)))

        with global_shard_guard(MeshResource()):
            _, grads1 = fused_val_grad(q, k, v)
            fused_dq1 = np.array(grads1[0].block_until_ready())
            fused_dk1 = np.array(grads1[1].block_until_ready())
            fused_dv1 = np.array(grads1[2].block_until_ready())

            _, grads2 = fused_val_grad(q, k, v)
            fused_dq2 = np.array(grads2[0].block_until_ready())
            fused_dk2 = np.array(grads2[1].block_until_ready())
            fused_dv2 = np.array(grads2[2].block_until_ready())

            # Bitwise reproducibility across consecutive runs
            np.testing.assert_array_equal(fused_dq1, fused_dq2, err_msg="dQ not bitwise reproducible")
            np.testing.assert_array_equal(fused_dk1, fused_dk2, err_msg="dK not bitwise reproducible")
            np.testing.assert_array_equal(fused_dv1, fused_dv2, err_msg="dV not bitwise reproducible")

            # Numerical correctness vs unfused JAX reference (O(s^2), skip for large s)
            if check_numerical:
                def ref_fn(q, k, v):
                    return jax_dpa(q, k, v, None, mask, None, **kwargs).astype(jnp.float32).sum()

                ref_val_grad = jit(jax.value_and_grad(ref_fn, argnums=(0, 1, 2)))
                _, ref_grads = ref_val_grad(q, k, v)
                ref_dq = ref_grads[0].block_until_ready()
                ref_dk = ref_grads[1].block_until_ready()
                ref_dv = ref_grads[2].block_until_ready()

                assert_allclose(jnp.array(fused_dq1), ref_dq, dtype=dtype)
                assert_allclose(jnp.array(fused_dk1), ref_dk, dtype=dtype)
                assert_allclose(jnp.array(fused_dv1), ref_dv, dtype=dtype)
    finally:
        # Restore the caller's env so deterministic mode never leaks into
        # other tests in the same process.
        if _orig_nondeterministic is None:
            os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
        else:
            os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic


@pytest.mark.skipif(
    not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware"
)
@pytest.mark.parametrize(
    "qkv_layout",
    [QKVLayout.BS3HD, QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD],
    ids=["QKV_PACKED", "KV_PACKED", "SEPARATE"],
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK],
    ids=["NO_MASK", "CAUSAL"],
)
@pytest.mark.parametrize(
    "b, seq_len, h_q, h_kv, d",
    [
        (2, 256, 8, 8, 128),
        (2, 2048, 8, 8, 128),
    ],
    ids=["b2_s256_MHA", "b2_s2048_MHA"],
)
def test_deterministic_bwd(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d):
    """CK deterministic backward across layouts and masks: gradients must be
    bitwise reproducible run-to-run, with numerical correctness checked where
    the shared helper enables it."""
    _run_deterministic_bwd_case(qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d)


@pytest.mark.skipif(
    not is_hip_extension(), reason="CK deterministic backward only applies to AMD hardware"
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK],
    ids=["NO_MASK", "CAUSAL"],
)
def test_deterministic_bwd_gqa(attn_mask_type):
    """GQA variant: BSHD_BSHD_BSHD with h_q != h_kv."""
    # seq_len=2048 keeps the deterministic multi-split path active; the
    # numerical check is disabled because the unfused reference is O(s^2).
    _run_deterministic_bwd_case(
        qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
        attn_mask_type=attn_mask_type,
        b=2,
        seq_len=2048,
        h_q=12,
        h_kv=4,
        d=128,
        check_numerical=False,
    )
6 changes: 3 additions & 3 deletions transformer_engine/common/fused_attn_rocm/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ void nvte_fused_attn_bwd_qkvpacked(
attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right,
false, // TODO: enable deterministic after CK team show us how
deterministic,
input_QKV, input_O, input_dO, input_Bias, output_S,
output_dQKV, output_dBias,
input_cu_seqlens, input_cu_seqlens_padded,
Expand Down Expand Up @@ -677,7 +677,7 @@ void nvte_fused_attn_bwd_kvpacked(
attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right,
false, // TODO: enable deterministic after CK team show us how
deterministic,
input_Q, input_KV, input_O, input_dO, input_Bias,
output_S,
output_dQ, output_dKV, output_dBias,
Expand Down Expand Up @@ -863,7 +863,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type,
window_size_left, window_size_right,
false, // TODO: enable deterministic after CK team show us how
deterministic,
input_Q, input_K, input_V, input_O, input_dO, input_Bias,
output_S,
output_dQ, output_dK, output_dV, output_dBias,
Expand Down
Loading