[TE] Enable deterministic mode for fused attention #508
base: dev
```diff
@@ -9,6 +9,7 @@
 from functools import partial
 from math import sqrt
 from typing import Tuple, Optional, Dict
+import os
 import random

 import jax
```
```diff
@@ -26,7 +27,7 @@
 from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
 from transformer_engine.jax import autocast
-from transformer_engine.jax.sharding import MeshResource
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
 from transformer_engine.jax.attention import (
     AttnBiasType,
     AttnMaskType,
```
```diff
@@ -1252,3 +1253,165 @@ def test_jax_new_rng():
     )
     runner = FusedAttnRunner(**kwargs)
     runner.test_forward()
+
+
+def _run_deterministic_bwd_case(
+    qkv_layout, attn_mask_type, b, seq_len, h_q, h_kv, d, check_numerical=None
+):
+    """
+    Shared helper for deterministic backward tests.
+
+    Verifies that the CK fused attention backward pass in deterministic mode
+    produces bitwise-reproducible gradients. Optionally checks numerical
+    correctness against an unfused JAX reference (O(s^2)); this is skipped for
+    large seq_len to keep CI fast.
+
+    All seq_len values should be >= 256 so that nsplits = ceil(s/kN0) > 1
+    (kN0=128 for d<=128), ensuring the kernel actually exercises the
+    deterministic split-accumulator path rather than the trivial single-split.
+    """
+    if check_numerical is None:
+        check_numerical = seq_len <= 256
```
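The seq_len >= 256 constraint in the docstring can be illustrated with a quick arithmetic sketch (kN0 = 128 is the tile size quoted above; the helper name is hypothetical, not part of the PR):

```python
import math

def nsplits(seq_len, kN0=128):
    # Number of key/value splits the backward kernel accumulates over;
    # with kN0=128, seq_len must exceed 128 to produce more than one split.
    return math.ceil(seq_len / kN0)

assert nsplits(128) == 1  # trivial single split: deterministic path not exercised
assert nsplits(256) == 2  # multi-split: deterministic accumulation order matters
assert nsplits(512) == 4
```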
Collaborator (on lines +1273 to +1274): Why do we skip checking the numerical correctness for cases with seqlen <= 256?
```diff
+    s = seq_len
+    dtype = jnp.bfloat16
```

Collaborator: Let's check for both bf16 and fp16.

```diff
+    scaling_factor = 1.0 / sqrt(d)
```
```diff
+    # Set deterministic mode before any TE calls so the flag is visible
+    # throughout backend selection and kernel dispatch.
+    _orig_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
+    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
```
(On lines +1281 to +1282: AllenFarcas marked this conversation as resolved.)
```diff
+    # Verify the CK backend is selected, otherwise the test is meaningless
+    backend = FusedAttnHelper(
+        True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type,
+        0.0, h_q, h_kv, s, s, d, d, (-1, -1),
+    ).get_fused_attn_backend()
+    if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
+        pytest.skip("No fused attention backend available for this config")
+    assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
+        f"Expected CK backend but got {backend}."
+    )
```
Comment on lines +1291 to +1293 — suggested change:

```diff
-    assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
-        f"Expected CK backend but got {backend}."
-    )
+    if backend != NVTE_Fused_Attn_Backend.NVTE_CK:
+        pytest.skip(f"Deterministic CK test requires CK backend, got {backend}.")
```

Reviewer: This will indeed report it as a fail, so let's skip instead.

Reviewer: Technically, if we specify NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, the backend selection should take this env and choose a deterministic backend for us, not restrict to CK. As I recall, aotriton is by its nature deterministic @xinyazhang

Copilot AI (Mar 27, 2026): MeshResource() is likely an invalid constructor call in this codebase. In this test suite, MeshResource is commonly instantiated with explicit axis/resource names (e.g. MeshResource('dp', 'cp', 'tp')), and calling it with no arguments may raise TypeError at runtime, preventing the test from running. Construct MeshResource with the expected arguments (or reuse the same default resource configuration used elsewhere in this file). Suggested change:

```diff
-    with global_shard_guard(MeshResource()):
+    with global_shard_guard(MeshResource('dp', 'cp', 'tp')):
```
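The bitwise-reproducibility property this helper checks can be sketched independently of TE with a toy unfused attention in pure JAX (all names here are illustrative, not the PR's actual test code; a determinism test compares gradients with exact equality, not allclose):

```python
import jax
import jax.numpy as jnp

def toy_attention_loss(q, k, v):
    # Unfused O(s^2) reference attention, reduced to a scalar loss.
    scores = q @ k.T / jnp.sqrt(jnp.float32(q.shape[-1]))
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.sum(probs @ v)

key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (256, 64), jnp.float32)
k = jax.random.normal(kk, (256, 64), jnp.float32)
v = jax.random.normal(kv, (256, 64), jnp.float32)

grad_fn = jax.jit(jax.grad(toy_attention_loss))
g1 = grad_fn(q, k, v)
g2 = grad_fn(q, k, v)
# Bitwise comparison: two runs of the same compiled gradient on the same
# inputs must agree in every bit for the pass to count as deterministic.
assert bool((g1 == g2).all())
```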
Copilot AI (Mar 27, 2026): This test mutates a process-global environment variable. Even though it's restored in finally, using pytest's monkeypatch fixture (e.g., monkeypatch.setenv / monkeypatch.delenv) would be more robust and idiomatic, and reduces the risk of state leaking if this helper evolves (e.g., added early returns) or is reused elsewhere. Suggested change:

```diff
-    if _orig_nondeterministic is None:
-        os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
-    else:
-        os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic
+    monkeypatch = pytest.MonkeyPatch()
+    if _orig_nondeterministic is None:
+        monkeypatch.delenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", raising=False)
+    else:
+        monkeypatch.setenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", _orig_nondeterministic)
```
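For reference, the fixture form Copilot alludes to would look roughly like this (the test name and the stand-in check are illustrative; a real test would invoke the TE code paths instead):

```python
import os
import pytest

def _deterministic_mode_enabled():
    # Stand-in for TE's internal check of the determinism env flag.
    return os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0"

def test_deterministic_bwd(monkeypatch):
    # monkeypatch restores the variable automatically when the test ends,
    # even if the test raises, so no try/finally bookkeeping is needed.
    monkeypatch.setenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "0")
    assert _deterministic_mode_enabled()
```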
Reviewer: Let's not restrict to NO_MASK or CAUSAL; let's add causal, padding, padding causal, and padding causal bottom right as well.

Reviewer: Also, extend to non-GQA cases as well.

Reviewer: Also, check some sequence packing cases.

Reviewer: Why do we need this?
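The broader coverage the reviewers ask for (more mask types, non-GQA head configurations) could be laid out as a parametrized test; this is only a sketch with placeholder strings, since the real test would use TE's AttnMaskType enum members, whose exact names should be checked against the codebase:

```python
import pytest

# Placeholder mask names; the real test would parametrize over AttnMaskType.
MASK_TYPES = [
    "no_mask",
    "causal",
    "padding",
    "padding_causal",
    "padding_causal_bottom_right",
]

# (h_q, h_kv): equal head counts = plain MHA (non-GQA), h_q > h_kv = GQA.
HEAD_CONFIGS = [(16, 16), (16, 4)]

@pytest.mark.parametrize("attn_mask_type", MASK_TYPES)
@pytest.mark.parametrize("h_q,h_kv", HEAD_CONFIGS)
def test_deterministic_bwd_coverage(attn_mask_type, h_q, h_kv):
    # Would dispatch to _run_deterministic_bwd_case(...) with these params.
    assert h_q % h_kv == 0  # GQA requires query heads divisible by KV heads
```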