[TE] Enable deterministic mode for fused attention #508
AllenFarcas wants to merge 3 commits into dev from
Conversation
Pull request overview
Enables deterministic mode propagation for ROCm fused-attention backward (CK backend) and adds JAX coverage to validate bitwise reproducibility and gradient correctness when non-deterministic algorithms are disallowed.
Changes:
- Forward the `deterministic` flag from the NVTE ROCm fused-attn backward entrypoints into the CK backend calls.
- Add JAX tests that (on HIP/AMD) verify backward gradients are bitwise reproducible across runs and match an unfused JAX reference.
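The reproducibility property these tests target can be illustrated with a minimal JAX sketch (not the PR's actual test code): compute the same gradient twice from identical inputs and require exact equality, not just `allclose`.

```python
import jax
import jax.numpy as jnp

# Minimal sketch of a bitwise-reproducibility check; `loss` is a toy
# attention-like function, not the PR's fused kernel.
def loss(q):
    return jnp.sum(jax.nn.softmax(q @ q.T) @ q)

grad_fn = jax.jit(jax.grad(loss))
q = jnp.arange(32, dtype=jnp.float32).reshape(4, 8) / 32.0
g1 = grad_fn(q)
g2 = grad_fn(q)
# Deterministic mode demands bit-identical results, not tolerance-based equality.
assert bool((g1 == g2).all())
```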
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| transformer_engine/common/fused_attn_rocm/fused_attn.cpp | Passes the deterministic argument into CK fused-attn backward implementations (qkvpacked/kvpacked/separate). |
| tests/jax/test_fused_attn.py | Adds HIP-only deterministic-backward tests and imports global_shard_guard to ensure mesh resource context is set. |
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
```python
fused_val_grad = jit(jax.value_and_grad(fused_fn, argnums=(0, 1, 2)))
```

```python
with global_shard_guard(MeshResource()):
```
MeshResource() is likely an invalid constructor call in this codebase. In this test suite, MeshResource is commonly instantiated with explicit axis/resource names (e.g. MeshResource('dp', 'cp', 'tp')), and calling it with no arguments may raise TypeError at runtime, preventing the test from running. Construct MeshResource with the expected arguments (or reuse the same default resource configuration used elsewhere in this file).
```diff
-with global_shard_guard(MeshResource()):
+with global_shard_guard(MeshResource('dp', 'cp', 'tp')):
```
```python
assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
    f"Expected CK backend but got {backend}."
)
```
This hard assert makes the test brittle if another valid HIP backend is selected in the future (or on some configurations), even though the deterministic behavior under test could still be correct. Consider replacing the assert with pytest.skip(...) when backend != NVTE_CK, or explicitly constraining the test inputs to only configurations that can select CK deterministically.
```diff
-assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
-    f"Expected CK backend but got {backend}."
-)
+if backend != NVTE_Fused_Attn_Backend.NVTE_CK:
+    pytest.skip(f"Deterministic CK test requires CK backend, got {backend}.")
```
This will indeed report it as a fail, so let's skip instead.
```python
_orig_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
```
This test mutates a process-global environment variable. Even though it’s restored in finally, using pytest’s monkeypatch fixture (e.g., monkeypatch.setenv / monkeypatch.delenv) would be more robust and idiomatic, and reduces the risk of state leaking if this helper evolves (e.g., added early returns) or is reused elsewhere.
Yeah, Copilot gave a good comment. You will need to cache the outside env values and reset them after we finish this pytest.
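A minimal sketch of the cache-and-restore pattern using pytest's `MonkeyPatch` as a context manager (illustrative, not the PR's code): the original value is restored automatically on exit, even if the body raises or returns early.

```python
import os
import pytest

os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)  # start unset
with pytest.MonkeyPatch.context() as mp:
    # Force deterministic algorithms for the duration of the block.
    mp.setenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "0")
    assert os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] == "0"
# On exit, the previous (unset) state is restored automatically.
assert "NVTE_ALLOW_NONDETERMINISTIC_ALGO" not in os.environ
```

Inside a test function, the same behavior comes for free from the `monkeypatch` fixture with no explicit context manager.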
```python
if _orig_nondeterministic is None:
    os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
else:
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic
```
This test mutates a process-global environment variable. Even though it’s restored in finally, using pytest’s monkeypatch fixture (e.g., monkeypatch.setenv / monkeypatch.delenv) would be more robust and idiomatic, and reduces the risk of state leaking if this helper evolves (e.g., added early returns) or is reused elsewhere.
```diff
-if _orig_nondeterministic is None:
-    os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
-else:
-    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic
+monkeypatch = pytest.MonkeyPatch()
+if _orig_nondeterministic is None:
+    monkeypatch.delenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", raising=False)
+else:
+    monkeypatch.setenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", _orig_nondeterministic)
```
Unless we want to support non-deterministic CK only for the JAX integration, we should probably also add some tests on the PyTorch integration side, since it'll be enabled there too. Also, I think you still need to adjust TransformerEngine/transformer_engine/pytorch/attention/dot_product_attention/utils.py, lines 1070 to 1078 (at 82617fe).
wangye805 left a comment
BTW, add some deterministic test cases on the PyTorch side as well.
```python
if check_numerical is None:
    check_numerical = seq_len <= 256
```
Why do we only check the numerics when seq_len <= 256, skipping the numerical check for longer sequences?
```diff
 from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
 from transformer_engine.jax import autocast
-from transformer_engine.jax.sharding import MeshResource
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
```
```python
if check_numerical is None:
    check_numerical = seq_len <= 256
s = seq_len
dtype = jnp.bfloat16
```
Let's check for both bf16 and fp16
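A hedged sketch of the requested dtype coverage via parametrization (the test name is hypothetical; string stand-ins for `jnp.bfloat16`/`jnp.float16` keep the sketch self-contained):

```python
import pytest

# Illustrative: the real test would pass the jnp dtypes, not strings.
DTYPE_PARAMS = [
    pytest.param("bfloat16", id="BF16"),
    pytest.param("float16", id="FP16"),
]

@pytest.mark.parametrize("dtype", DTYPE_PARAMS)
def test_deterministic_bwd_dtype(dtype):  # hypothetical test name
    assert dtype in ("bfloat16", "float16")

# The explicit ids make failures readable, e.g. test_deterministic_bwd_dtype[BF16].
assert [p.id for p in DTYPE_PARAMS] == ["BF16", "FP16"]
```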
```python
backend = FusedAttnHelper(
    True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type,
    0.0, h_q, h_kv, s, s, d, d, (-1, -1),
).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
    pytest.skip("No fused attention backend available for this config")
assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
    f"Expected CK backend but got {backend}."
)
```
Technically, if we specify NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, the backend selection should honor this env var and choose a deterministic backend for us, not restrict the test to CK. As I recall, aotriton is deterministic by nature @xinyazhang
```python
"attn_mask_type",
[
    pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
    pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
```
Let's not restrict this to NO_MASK and CAUSAL; let's add padding, padding-causal, and padding-causal-bottom-right as well.
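One way to sketch the broader mask coverage (string stand-ins for the `AttnMaskType` enum members keep this self-contained; the exact member names in the codebase may differ):

```python
import pytest

# Hypothetical expanded parametrization; the real test would use
# AttnMaskType enum members instead of these string stand-ins.
MASK_PARAMS = [
    pytest.param("NO_MASK", id="NO_MASK"),
    pytest.param("CAUSAL_MASK", id="CAUSAL"),
    pytest.param("PADDING_MASK", id="PADDING"),
    pytest.param("PADDING_CAUSAL_MASK", id="PADDING_CAUSAL"),
    pytest.param("PADDING_CAUSAL_BOTTOM_RIGHT_MASK", id="PADDING_CAUSAL_BOTTOM_RIGHT"),
]
assert len(MASK_PARAMS) == 5
```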
```python
    ],
)
def test_deterministic_bwd_gqa(attn_mask_type):
    """GQA variant: BSHD_BSHD_BSHD with h_q != h_kv."""
```
Also, extend this to non-GQA (h_q == h_kv) cases as well.
```python
_run_deterministic_bwd_case(
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    attn_mask_type=attn_mask_type,
    b=2, seq_len=2048, h_q=12, h_kv=4, d=128,
```
Also, check some sequence packing cases
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes https://github.com/ROCm/frameworks-internal/issues/15875
Type of change
Changes
Checklist: