Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions qa/L2_jax_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
# Make mnist and encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
# Note: mnist intentionally does NOT set --xla_gpu_deterministic_ops because it
# significantly slows down small conv/GEMM kernels and was causing CI timeouts.
# The mnist verify() already uses a tail-window min/max with relaxed thresholds
# to be robust to run-to-run numerical noise.
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"

pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Duplicate --xla_gpu_deterministic_ops accumulation in XLA_FLAGS

XLA_FLAGS is appended with --xla_gpu_deterministic_ops on line 42 and then again on line 45, so the second encoder run (NVTE_JAX_CUSTOM_CALLS=false) sees the flag twice in the environment variable. This was a pre-existing pattern carried forward by this PR. XLA likely tolerates duplicate flags, but it's worth using a guard to avoid the accumulation:

Suggested change
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
export XLA_FLAGS="${XLA_FLAGS:+$XLA_FLAGS }--xla_gpu_deterministic_ops"

Or simply remove the redundant second export since the flag is already set from line 42.

Expand Down
Loading