
[TE] Enable deterministic mode for fused attention #508

Open

AllenFarcas wants to merge 3 commits into dev from alfarcas/aima60-fix

Conversation

AllenFarcas (Contributor) commented Mar 27, 2026

Description


Fixes https://github.com/ROCm/frameworks-internal/issues/15875

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added deterministic functionality to fused attention
  • Added test for the introduced deterministic functionality

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI left a comment

Pull request overview

Enables deterministic mode propagation for ROCm fused-attention backward (CK backend) and adds JAX coverage to validate bitwise reproducibility and gradient correctness when non-deterministic algorithms are disallowed.

Changes:

  • Forward the deterministic flag from NVTE ROCm fused-attn backward entrypoints into CK backend calls.
  • Add JAX tests that (on HIP/AMD) verify backward gradients are bitwise reproducible across runs and match an unfused JAX reference.
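To illustrate the kind of check described above (this is a sketch, not the PR's actual test code), bitwise reproducibility means comparing the raw byte representation of two runs rather than using a numerical tolerance; `grad_fn` here is a hypothetical stand-in for the fused backward pass:

```python
import struct

def bitwise_equal(xs, ys):
    """Compare two float sequences at the bit level, not within a tolerance."""
    pack = lambda vals: b"".join(struct.pack("<d", v) for v in vals)
    return pack(xs) == pack(ys)

def grad_fn(inputs):
    # Hypothetical deterministic backward pass: a fixed-order computation,
    # so repeated runs produce bit-identical results.
    return [2.0 * x for x in inputs]

inputs = [0.1, 0.2, 0.3]
run1 = grad_fn(inputs)
run2 = grad_fn(inputs)
assert bitwise_equal(run1, run2)  # deterministic: identical bits across runs
```

A tolerance-based `allclose` would pass even for non-deterministic kernels whose atomics reorder floating-point accumulation; the byte-level comparison is what distinguishes a truly deterministic backward.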

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Files reviewed:

  • transformer_engine/common/fused_attn_rocm/fused_attn.cpp: Passes the deterministic argument into CK fused-attn backward implementations (qkvpacked/kvpacked/separate).
  • tests/jax/test_fused_attn.py: Adds HIP-only deterministic-backward tests and imports global_shard_guard to ensure the mesh resource context is set.


Copilot AI left a comment
Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.





fused_val_grad = jit(jax.value_and_grad(fused_fn, argnums=(0, 1, 2)))

with global_shard_guard(MeshResource()):
Copilot AI commented Mar 27, 2026:
MeshResource() is likely an invalid constructor call in this codebase. In this test suite, MeshResource is commonly instantiated with explicit axis/resource names (e.g. MeshResource('dp', 'cp', 'tp')), and calling it with no arguments may raise TypeError at runtime, preventing the test from running. Construct MeshResource with the expected arguments (or reuse the same default resource configuration used elsewhere in this file).

Suggested change:

    - with global_shard_guard(MeshResource()):
    + with global_shard_guard(MeshResource('dp', 'cp', 'tp')):

Comment on lines +1291 to +1293
assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
    f"Expected CK backend but got {backend}."
)
Copilot AI commented Mar 27, 2026:

This hard assert makes the test brittle if another valid HIP backend is selected in the future (or on some configurations), even though the deterministic behavior under test could still be correct. Consider replacing the assert with pytest.skip(...) when backend != NVTE_CK, or explicitly constraining the test inputs to only configurations that can select CK deterministically.

Suggested change:

    - assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
    -     f"Expected CK backend but got {backend}."
    - )
    + if backend != NVTE_Fused_Attn_Backend.NVTE_CK:
    +     pytest.skip(f"Deterministic CK test requires CK backend, got {backend}.")

Contributor replied:

This will indeed report it as a fail, so let's skip instead.

Comment on lines +1281 to +1282
_orig_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
Copilot AI commented Mar 27, 2026:
This test mutates a process-global environment variable. Even though it’s restored in finally, using pytest’s monkeypatch fixture (e.g., monkeypatch.setenv / monkeypatch.delenv) would be more robust and idiomatic, and reduces the risk of state leaking if this helper evolves (e.g., added early returns) or is reused elsewhere.

Collaborator replied:

Yeah, Copilot gave a good comment. You will need to cache the outside env value and reset it after the pytest finishes.
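A minimal stdlib-only sketch of the caching pattern being discussed, using `unittest.mock.patch.dict` so the caller's original value is restored (or the key removed) automatically; `run_with_deterministic_algo` is an illustrative helper name, not the test's actual code:

```python
import os
from unittest import mock

def run_with_deterministic_algo(fn):
    """Force NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 while fn runs.

    patch.dict snapshots os.environ on entry and restores it on exit,
    so any pre-existing value (or absence) of the variable is preserved.
    """
    with mock.patch.dict(os.environ, {"NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0"}):
        return fn()

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
seen = run_with_deterministic_algo(
    lambda: os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"]
)
assert seen == "0"  # forced inside the guard
assert os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] == "1"  # restored after
```

Unlike a manual try/finally, the restore here also survives early returns or new code paths added later, which is the robustness point Copilot raised.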

Comment on lines +1364 to +1367
if _orig_nondeterministic is None:
    os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
else:
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic
Copilot AI commented Mar 27, 2026:
This test mutates a process-global environment variable. Even though it’s restored in finally, using pytest’s monkeypatch fixture (e.g., monkeypatch.setenv / monkeypatch.delenv) would be more robust and idiomatic, and reduces the risk of state leaking if this helper evolves (e.g., added early returns) or is reused elsewhere.

Suggested change:

    - if _orig_nondeterministic is None:
    -     os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
    - else:
    -     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = _orig_nondeterministic
    + monkeypatch = pytest.MonkeyPatch()
    + if _orig_nondeterministic is None:
    +     monkeypatch.delenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", raising=False)
    + else:
    +     monkeypatch.setenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", _orig_nondeterministic)

AllenFarcas added the ci-level 1 (CI test level 1) label Mar 27, 2026
Micky774 (Contributor) commented Mar 27, 2026

Unless we want to support deterministic CK only for the JAX integration, we should probably also add some tests on the PyTorch integration side, since it'll be enabled there too.

Also, I think you still need to adjust:

    # TODO: remove the filtering after ck team tells us how to enable more deterministic bwd kernels
    if use_fused_attention and deterministic and IS_HIP_EXTENSION:
        if (
            fused_attention_backend == FusedAttnBackend["CK"]
            and is_training
        ):
            logger.debug("Disabling FusedAttention for determinism reasons")
            use_fused_attention = False
            fused_attention_backend = None  # TODO: switch to AOTriton when supported

wangye805 (Collaborator) left a comment

BTW, add some deterministic test cases on the PyTorch side as well.

Comment on lines +1273 to +1274
if check_numerical is None:
    check_numerical = seq_len <= 256
wangye805 (Collaborator) commented:

Why do we only check numerics for cases with seq_len <= 256?
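If the concern behind the length cutoff is accumulation error in long reductions rather than correctness, one alternative to skipping the numerical check entirely is to widen the tolerance with sequence length. A purely illustrative sketch (the scaling rule, cutoff of 256, and `base` value are assumptions, not what the PR does):

```python
def tolerance_for(seq_len, base=2e-2):
    """Widen the comparison tolerance as reductions get longer.

    Rounding error in a length-n sum grows roughly like sqrt(n) under
    a random-rounding model, so scale the base tolerance accordingly,
    never tightening below the base for short sequences.
    """
    return base * max(1.0, (seq_len / 256) ** 0.5)

# Short sequences keep the base tolerance; longer ones get a wider one.
assert tolerance_for(128) == 2e-2
assert tolerance_for(1024) == 4e-2
```

This keeps some numerical coverage at seq_len=2048 instead of relying solely on the bitwise-reproducibility assertion there.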

  from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
  from transformer_engine.jax import autocast
- from transformer_engine.jax.sharding import MeshResource
+ from transformer_engine.jax.sharding import MeshResource, global_shard_guard
wangye805 (Collaborator) commented:
Why do we need this?

if check_numerical is None:
    check_numerical = seq_len <= 256
s = seq_len
dtype = jnp.bfloat16
wangye805 (Collaborator) commented:
Let's check for both bf16 and fp16


Comment on lines +1285 to +1293
backend = FusedAttnHelper(
    True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type,
    0.0, h_q, h_kv, s, s, d, d, (-1, -1),
).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
    pytest.skip("No fused attention backend available for this config")
assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
    f"Expected CK backend but got {backend}."
)
wangye805 (Collaborator) commented:
Technically, if we specify NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, the backend selection should honor this env var and choose a deterministic backend for us, not restrict us to CK. As I recall, AOTriton is deterministic by nature. @xinyazhang
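The selection behavior being suggested can be sketched as follows. This is purely illustrative (the backend table, `allow_nondeterministic`, and `select_backend` are hypothetical names, not Transformer Engine's actual selection code): read the env var once and let any deterministic backend qualify, rather than hard-coding CK.

```python
import os

# Hypothetical capability table: which backends have deterministic bwd kernels.
DETERMINISTIC_BACKENDS = {"CK": False, "AOTriton": True}

def allow_nondeterministic():
    # Treat an unset NVTE_ALLOW_NONDETERMINISTIC_ALGO as "allowed".
    return os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") != "0"

def select_backend(candidates):
    """Pick the first viable candidate; when nondeterminism is disallowed,
    restrict the pool to backends known to be deterministic."""
    if allow_nondeterministic():
        pool = candidates
    else:
        pool = [b for b in candidates if DETERMINISTIC_BACKENDS.get(b, False)]
    return pool[0] if pool else None

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
assert select_backend(["CK", "AOTriton"]) == "AOTriton"  # deterministic choice
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
assert select_backend(["CK", "AOTriton"]) == "CK"  # preference order wins
```

Under this scheme the test could assert only that the selected backend is deterministic, instead of asserting `backend == NVTE_CK`.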

"attn_mask_type",
[
    pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
    pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
wangye805 (Collaborator) commented:
Let's not restrict this to NO_MASK and CAUSAL; let's add padding, padding causal, and padding causal bottom right as well.

],
)
def test_deterministic_bwd_gqa(attn_mask_type):
    """GQA variant: BSHD_BSHD_BSHD with h_q != h_kv."""
wangye805 (Collaborator) commented:
Also, extend this to non-GQA cases as well.

_run_deterministic_bwd_case(
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    attn_mask_type=attn_mask_type,
    b=2, seq_len=2048, h_q=12, h_kv=4, d=128,
wangye805 (Collaborator) commented:
Also, check some sequence packing cases

Labels: ci-level 1 (CI test level 1)

4 participants