Skip to content

Commit 4537cce

Browse files
Updated XLA_FLAGS in ci/jax.sh
1 parent c737072 commit 4537cce

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

ci/jax.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ run_test_config() {
5858
run_default_fa 1 test_custom_call_compute.py
5959
run_default_fa 1 test_functions.py
6060
run 1 test_fused_attn.py
61+
XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled
6162
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
6263
run_default_fa 1 test_helper.py
6364
run_default_fa 1 test_layer.py #it effectevly always uses unfused attention

tests/jax/test_fused_attn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1262,8 +1262,9 @@ def test_jax_new_rng():
12621262
@pytest.fixture
12631263
def ck_smallseq_env(monkeypatch):
12641264
"""Enable CK small-seq path and disable XLA GPU graphs for these tests."""
1265+
if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""):
1266+
pytest.skip("Run with XLA_FLAGS='--xla_gpu_graph_level=0' pytest ...")
12651267
monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1")
1266-
monkeypatch.setenv("XLA_FLAGS", "--xla_gpu_graph_level=0")
12671268
yield
12681269

12691270
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"])

0 commit comments

Comments
 (0)