From e0a5429883e1de3c96991e01563f8146eba44906 Mon Sep 17 00:00:00 2001 From: Jlonge4 Date: Fri, 10 Oct 2025 15:04:11 -0400 Subject: [PATCH 01/21] initial testing --- contributed/batch_invariance/README.md | 141 ++++++ .../test_batch_invariance_nki.py | 423 ++++++++++++++++++ 2 files changed, 564 insertions(+) create mode 100644 contributed/batch_invariance/README.md create mode 100644 contributed/batch_invariance/test_batch_invariance_nki.py diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md new file mode 100644 index 0000000..69f11b3 --- /dev/null +++ b/contributed/batch_invariance/README.md @@ -0,0 +1,141 @@ +# NKI Batch Invariance Test + +Testing whether NKI's tile size constraints protect against batch-dependent non-determinism in matrix multiplication. + +## Hypothesis + +**NKI achieves batch invariance by default due to hardware tile constraints.** + +Unlike CUDA/PyTorch, where batch size can influence the K-dimension reduction strategy (e.g., switching to split-K for better parallelism when M is small), NKI's hardware constraints enforce fixed tile sizes that decouple batch size from reduction order. + +### Key Protection Mechanisms + +1. **K is the reduction axis, not the batch axis (M)** + - Reduction happens over K (contraction dimension) + - M (batch) loop is outer, K loop is inner + - Changing M doesn't affect K iteration count + +2. **Hardware constraints enforce fixed tile sizes** + - Tensor Engine limits: P-dim ≤ 128, free-dim ≤ 512 + - Forces compile-time constants (e.g., K_TILE=128) + - Prevents runtime adaptation based on batch size + +3. **Potential vulnerability: Split-K** + - NKI *could* split along K when M is small (like CUDA does) + - This would couple M and K reduction strategy + - Our tests verify this doesn't happen automatically + +## Test Design + +Replicated [Thinking Machines' batch invariance test](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/): + +Instance_type: `inf2.xlarge` +AMI ID: `ami-0ec4ab14b1c5a10f2` +AMI NAME: `Deep Learning AMI Neuron (Ubuntu 22.04) 20250919` +```python +# CUDA shows non-determinism: +out1 = torch.mm(a[:1], b) # M=1 +out2 = torch.mm(a, b)[:1] # M=2048 +# Result: out1 ≠ out2 (diff: 1669.25) + +# NKI test: +out1 = matmul_nki(a[:128], b)[0] # M=128 +out2 = matmul_nki(a, b)[0] # M=2048 +# Result: out1 == out2 (diff: 0.0) ✓ +``` + +## Results + +### Test 1: M_TILE Variation (64 vs 128) +``` +M_TILE=64 → Result: [9664., 9600., ...] +M_TILE=128 → Result: [9664., 9600., ...] +Max difference: 0.0 ✓ INVARIANT +``` +**Conclusion:** Batch tiling strategy doesn't affect results. + +### Test 2: M (Batch Size) Variation (128 vs 2048) +``` +M=128 → Result: [9664., 9600., ...] +M=2048 → Result: [9664., 9600., ...] +Max difference: 0.0 ✓ INVARIANT +``` +**Conclusion:** True batch invariance achieved. Same element produces identical results regardless of batch size. + +### Test 3: K_TILE Variation (64 vs 128) - Simulated Dynamic Tiling +``` +K_TILE=128 → Result: [9664., 9600., ...] (32 iterations) +K_TILE=64 → Result: [9664., 9600., ...] (64 iterations) +Max difference: 256.0 ✓ VARIANT (expected) +``` +**Conclusion:** Reduction order matters. Different K_TILE → different accumulation order → different floating-point results. This simulates what CUDA does when it adapts K strategy based on batch size. + +### Test 4: Loop Iterator (affine_range vs sequential_range) +``` +affine_range → Result: [9664., 9600., ...] +sequential_range → Result: [9664., 9600., ...] +Max difference: 0.0 ✓ INVARIANT +``` +**Conclusion:** Loop iterator type is a compiler hint; doesn't affect numerical output. + +### Test 5: Precision Impact (bfloat16 vs float32) +``` +bfloat16 K_TILE diff: 256.0 (2.67% relative error) +float32 K_TILE diff: 15.125 (0.091% relative error) +Amplification: 16.9x +``` +**Conclusion:** Lower precision amplifies accumulation order effects. bfloat16's 7-bit mantissa shows 17x larger differences than float32's 23-bit mantissa. + +### Test 6: Consistency Check +``` +Run 1: 256.0 +Run 2: 256.0 +Run 3: 256.0 +✓ FULLY DETERMINISTIC +``` +**Conclusion:** The K_TILE difference is consistent and repeatable, not random. + +## Key Findings + +### ✅ Hypothesis Confirmed + +**NKI IS BATCH INVARIANT** +- M_TILE doesn't affect results (batch tiling invariant) +- M (batch size) doesn't affect results (true batch invariance) +- K_TILE DOES affect results (reduction order matters) +- But K_TILE is a compile-time constant → fully deterministic + +### 📊 Comparison: NKI vs CUDA + +| Aspect | CUDA | NKI | +|--------|------|-----| +| Batch size affects K reduction? | ✗ Yes (split-K adaptation) | ✅ No (fixed K_TILE) | +| Run-to-run deterministic? | ✗ No (varies ~1669) | ✅ Yes (always identical) | +| K_TILE matters? | ✅ Yes | ✅ Yes | +| Tile size constraints? | Flexible | Hardware-enforced (≤128/512) | + +### 🔬 Why NKI Wins + +1. **M/K decoupling:** Batch loop (M) is outer, reduction loop (K) is inner. Changing batch size doesn't affect K iteration count. + +2. **Hardware constraints as a feature:** Tensor Engine limits force compile-time K_TILE constants, preventing runtime adaptation. + +3. **No automatic split-K:** NKI doesn't dynamically switch to split-K based on batch size. You'd need to write a separate kernel. + +## Implications + +**For LLM Inference:** +- Batch-invariant by default (no special kernels needed like Thinking Machines built for CUDA) +- Deterministic sampling at temperature=0 (if K_TILE is fixed) +- True on-policy RL possible (identical numerics between training and inference) + +**Caveats:** +- K_TILE variation causes 2.67% relative error in bfloat16 (acceptable for most LLM use cases) +- Must use consistent K_TILE across kernels for bitwise reproducibility +- Lower precision (bfloat16) amplifies accumulation order effects 17x vs float32 + +## Conclusion + +NKI's tile size constraints, enforced by hardware limitations, provide batch invariance as an inherent property rather than requiring specialized implementations. The decoupling of batch size (M) from reduction strategy (K_TILE) ensures that the same element produces identical results regardless of the batch it's computed in. + +**Bottom line:** CUDA varies K reduction order *unpredictably* based on batch size. NKI keeps it *fixed* based on compile-time K_TILE. That's the win. diff --git a/contributed/batch_invariance/test_batch_invariance_nki.py b/contributed/batch_invariance/test_batch_invariance_nki.py new file mode 100644 index 0000000..206f143 --- /dev/null +++ b/contributed/batch_invariance/test_batch_invariance_nki.py @@ -0,0 +1,423 @@ +""" +Minimal NKI Batch Invariance Test - Clean Implementation + +Tests if dynamic M tiling introduces non-determinism in matmul. +Based on NKI matmul example pattern. +""" + +import torch +import torch_neuronx +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl + + +@nki.jit +def matmul_m64(a, b): + """ + Matmul with M tiled at 64 + a: [M, 4096], b: [4096, 512] + Output: [M, 512] + + Works with any M that's divisible by 64 + """ + M, K = a.shape + N = b.shape[1] + M_TILE = 64 + K_TILE = 128 + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + # Tile over M dimension + for m in nl.affine_range(M // M_TILE): + # Accumulator for this M chunk + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Reduction over K + for k in nl.affine_range(K // K_TILE): + # Load a: [M_TILE, K_TILE] + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + # Load b: [K_TILE, N] + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + # Matmul + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + # Store this M chunk + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result + + +@nki.jit +def matmul_m128(a, b): + """ + Matmul with M tiled at 128 + a: [M, 4096], b: [4096, 512] + Output: [M, 512] + + Works with any M that's divisible by 128 + """ + M, K = a.shape + N = b.shape[1] + M_TILE = 128 + K_TILE = 128 + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + # Tile over M dimension + for m in nl.affine_range(M // M_TILE): + # Accumulator for this M chunk + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Reduction over K + for k in nl.affine_range(K // K_TILE): + # Load a: [M_TILE, K_TILE] + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + # Load b: [K_TILE, N] + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + # Matmul + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + # Store this M chunk + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result + + +@nki.jit +def matmul_k64(a, b): + """ + Matmul with K tiled at 64 (different contraction tile size) + + This should produce DIFFERENT results than K_TILE=128 + because the reduction order changes! + """ + M, K = a.shape + N = b.shape[1] + M_TILE = 128 + K_TILE = 64 # DIFFERENT K tiling! + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + for m in nl.affine_range(M // M_TILE): + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Now we have TWICE as many K iterations (64 instead of 32) + for k in nl.affine_range(K // K_TILE): + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result + + +@nki.jit +def matmul_sequential(a, b): + """ + Matmul using sequential_range instead of affine_range + + sequential_range forces sequential execution with loop-carried dependency. + Question: Does this affect determinism? + """ + M, K = a.shape + N = b.shape[1] + M_TILE = 128 + K_TILE = 128 + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + for m in nl.affine_range(M // M_TILE): + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Using sequential_range - tells compiler there's loop dependency + for k in nl.sequential_range(K // K_TILE): + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result + + +@nki.jit +def matmul_m128_fp32(a, b): + """ + Matmul with M_TILE=128, but using float32 inputs + To compare precision differences vs bfloat16 + """ + M, K = a.shape + N = b.shape[1] + M_TILE = 128 + K_TILE = 128 + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + for m in nl.affine_range(M // M_TILE): + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + for k in nl.affine_range(K // K_TILE): + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result + + +@nki.jit +def matmul_k64_fp32(a, b): + """ + Matmul with K_TILE=64, using float32 inputs + """ + M, K = a.shape + N = b.shape[1] + M_TILE = 128 + K_TILE = 64 + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + for m in nl.affine_range(M // M_TILE): + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + for k in nl.affine_range(K // K_TILE): + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result + + +def test_batch_invariance(): + """ + Comprehensive batch invariance testing suite + """ + B, D, N = 2048, 4096, 512 + + # Create test inputs on XLA device + device = 'xla' + a = torch.linspace(-100, 100, B*D, device=device).reshape(B, D).to(torch.bfloat16) + b = torch.linspace(-100, 100, D*N, device=device).reshape(D, N).to(torch.bfloat16) + + print("=" * 70) + print("TEST 1: Different M_TILE on same input") + print("=" * 70) + print(f"Input: [{B}, {D}] @ [{D}, {N}]") + print(f"M_TILE=64: {B//64} iterations over M, K_TILE=128") + print(f"M_TILE=128: {B//128} iterations over M, K_TILE=128") + print() + + c_m64 = matmul_m64(a, b) + c_m128 = matmul_m128(a, b) + + c_m64_cpu = c_m64.cpu() + c_m128_cpu = c_m128.cpu() + + print("Results:") + print(f" M_TILE=64 row[0]: {c_m64_cpu[0, :5]}") + print(f" M_TILE=128 row[0]: {c_m128_cpu[0, :5]}") + + diff1 = (c_m64_cpu - c_m128_cpu).abs().max() + print(f"\n Max difference: {diff1.item()}") + print(f" Bitwise identical: {diff1.item() == 0}") + + print("\n" + "=" * 70) + print("TEST 2: Thinking Machines scenario - varying M (batch size)") + print("=" * 70) + print("The real batch invariance test!") + print(f"Compute row 0 with M=128 vs M=2048") + print() + + a_small = a[:128, :] + c_small = matmul_m128(a_small, b) + c_full = matmul_m128(a, b) + + c_small_cpu = c_small.cpu() + c_full_cpu = c_full.cpu() + + print("Results:") + print(f" M=128 row[0]: {c_small_cpu[0, :5]}") + print(f" M=2048 row[0]: {c_full_cpu[0, :5]}") + + diff2 = (c_small_cpu[0] - c_full_cpu[0]).abs().max() + print(f"\n Max difference: {diff2.item()}") + print(f" Bitwise identical: {diff2.item() == 0}") + + print("\n" + "=" * 70) + print("TEST 3: Different K_TILE - Does reduction order matter?") + print("=" * 70) + print("K_TILE=128: 32 K iterations (accumulate chunks: 0, 128, 256, ...)") + print("K_TILE=64: 64 K iterations (accumulate chunks: 0, 64, 128, ...)") + print("Different accumulation order → different floating point results!") + print() + + c_k128 = matmul_m128(a, b) # K_TILE=128 + c_k64 = matmul_k64(a, b) # K_TILE=64 + + c_k128_cpu = c_k128.cpu() + c_k64_cpu = c_k64.cpu() + + print("Results:") + print(f" K_TILE=128 row[0]: {c_k128_cpu[0, :5]}") + print(f" K_TILE=64 row[0]: {c_k64_cpu[0, :5]}") + + diff3 = (c_k128_cpu - c_k64_cpu).abs().max() + print(f"\n Max difference: {diff3.item()}") + print(f" Are they different? {diff3.item() != 0}") + + if diff3.item() != 0: + print(" ✓ EXPECTED! Different K_TILE → different reduction order") + else: + print(" ✗ UNEXPECTED! K_TILE should matter for floating point") + + print("\n" + "=" * 70) + print("TEST 4: sequential_range vs affine_range") + print("=" * 70) + print("affine_range: parallel-friendly, allows loop optimizations") + print("sequential_range: forces sequential execution, loop dependency") + print("Question: Do they produce identical results?") + print() + + c_affine = matmul_m128(a, b) # Uses affine_range + c_sequential = matmul_sequential(a, b) # Uses sequential_range + + c_affine_cpu = c_affine.cpu() + c_sequential_cpu = c_sequential.cpu() + + print("Results:") + print(f" affine_range row[0]: {c_affine_cpu[0, :5]}") + print(f" sequential_range row[0]: {c_sequential_cpu[0, :5]}") + + diff4 = (c_affine_cpu - c_sequential_cpu).abs().max() + print(f"\n Max difference: {diff4.item()}") + print(f" Bitwise identical: {diff4.item() == 0}") + + if diff4.item() == 0: + print(" ✓ Loop iterator type doesn't affect determinism!") + else: + print(" ✗ sequential_range changes results!") + + print("\n" + "=" * 70) + print("TEST 5: Precision Test - bfloat16 vs float32") + print("=" * 70) + print("Does reduced precision (bfloat16) amplify K_TILE differences?") + print("bfloat16: 7 bits mantissa, ~2-3 decimal digits precision") + print("float32: 23 bits mantissa, ~7 decimal digits precision") + print() + + # Create float32 inputs + a_fp32 = torch.linspace(-100, 100, B*D, device=device).reshape(B, D).to(torch.float32) + b_fp32 = torch.linspace(-100, 100, D*N, device=device).reshape(D, N).to(torch.float32) + + # Run with different K_TILE on float32 + c_k128_fp32 = matmul_m128_fp32(a_fp32, b_fp32) + c_k64_fp32 = matmul_k64_fp32(a_fp32, b_fp32) + + c_k128_fp32_cpu = c_k128_fp32.cpu() + c_k64_fp32_cpu = c_k64_fp32.cpu() + + print("Results (float32):") + print(f" K_TILE=128 row[0]: {c_k128_fp32_cpu[0, :5]}") + print(f" K_TILE=64 row[0]: {c_k64_fp32_cpu[0, :5]}") + + diff5_fp32 = (c_k128_fp32_cpu - c_k64_fp32_cpu).abs().max() + print(f"\n Max difference (float32): {diff5_fp32.item()}") + + print("\nComparison:") + print(f" bfloat16 K_TILE diff: {diff3.item()}") + print(f" float32 K_TILE diff: {diff5_fp32.item()}") + print(f" Ratio (bf16/fp32): {diff3.item() / diff5_fp32.item():.2f}x") + + if diff5_fp32.item() < diff3.item(): + print(f"\n ✓ float32 reduces error by {diff3.item() / diff5_fp32.item():.1f}x!") + print(" Lower precision (bfloat16) amplifies accumulation order effects") + else: + print("\n ✗ Unexpected: float32 doesn't reduce error significantly") + + # Also check: Is the difference consistent across runs? + print("\n" + "=" * 70) + print("TEST 6: Consistency Check - Is K_TILE difference stable?") + print("=" * 70) + print("Running K_TILE test 3 times to verify determinism...") + print() + + diffs = [] + for run in range(3): + c_k128_run = matmul_m128(a, b) + c_k64_run = matmul_k64(a, b) + diff_run = (c_k128_run.cpu() - c_k64_run.cpu()).abs().max().item() + diffs.append(diff_run) + print(f" Run {run+1}: max diff = {diff_run}") + + if len(set(diffs)) == 1: + print(f"\n ✓ FULLY DETERMINISTIC! All runs: {diffs[0]}") + print(" The 256.0 difference is consistent and repeatable") + else: + print(f"\n ✗ Non-deterministic! Diffs vary: {diffs}") + + print("\n" + "=" * 70) + print("FINAL VERDICT") + print("=" * 70) + + print(f"\n1. M_TILE variation (64 vs 128): {'✓ INVARIANT' if diff1.item() == 0 else '✗ VARIANT'}") + print(f"2. M variation (128 vs 2048): {'✓ INVARIANT' if diff2.item() == 0 else '✗ VARIANT'}") + print(f"3. K_TILE variation (64 vs 128): {'✓ VARIANT (expected)' if diff3.item() != 0 else '✗ INVARIANT (unexpected)'}") + print(f"4. Loop iterator (affine vs seq): {'✓ INVARIANT' if diff4.item() == 0 else '✗ VARIANT'}") + print(f"5. Precision (bf16 vs fp32): {diff3.item():.1f} vs {diff5_fp32.item():.4f} ({diff3.item()/diff5_fp32.item():.1f}x)") + print(f"6. Consistency across runs: {'✓ DETERMINISTIC' if len(set(diffs)) == 1 else '✗ NON-DETERMINISTIC'}") + + if diff1.item() == 0 and diff2.item() == 0: + print("\n" + "🎉 " * 20) + print("NKI IS BATCH INVARIANT!") + print(" • M_TILE doesn't affect results (batch tiling invariant)") + print(" • M (batch size) doesn't affect results (true batch invariance)") + print(" • K_TILE DOES affect results (reduction order matters)") + print(f" • bfloat16 amplifies differences by {diff3.item()/diff5_fp32.item():.1f}x vs float32") + print(" • But for FIXED K_TILE, results are fully deterministic!") + print("🎉 " * 20) + else: + print("\n✗ Batch invariance NOT achieved") + + +if __name__ == "__main__": + test_batch_invariance() \ No newline at end of file From 7eff4e95ce6b21d6713d999dee877cbf68d7ac0a Mon Sep 17 00:00:00 2001 From: Jlonge4 Date: Wed, 15 Oct 2025 10:16:44 -0400 Subject: [PATCH 02/21] replicate rmsnorm --- contributed/batch_invariance/README.md | 326 ++++++++++---- .../batch_invariance/kernels/__init__.py | 0 .../kernels/matmul_batch_invariant.py | 57 +++ .../kernels/rmsnorm_batch_invariant.py | 82 ++++ .../kernels/rmsnorm_split_reduction.py | 104 +++++ .../batch_invariance/test_batch_invariance.py | 149 ++++++ .../test_batch_invariance_nki.py | 423 ------------------ 7 files changed, 620 insertions(+), 521 deletions(-) create mode 100644 contributed/batch_invariance/kernels/__init__.py create mode 100644 contributed/batch_invariance/kernels/matmul_batch_invariant.py create mode 100644 contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py create mode 100644 contributed/batch_invariance/kernels/rmsnorm_split_reduction.py create mode 100644 contributed/batch_invariance/test_batch_invariance.py delete mode 100644 contributed/batch_invariance/test_batch_invariance_nki.py diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md index 69f11b3..26e6b36 100644 --- a/contributed/batch_invariance/README.md +++ b/contributed/batch_invariance/README.md @@ -1,141 +1,271 @@ # NKI Batch Invariance Test -Testing whether NKI's tile size constraints protect against batch-dependent non-determinism in matrix multiplication. +Demonstrating batch invariance principles in NKI (Neuron Kernel Interface), replicating findings from [Thinking Machines' "Defeating Nondeterminism in LLM Inference"](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/). -## Hypothesis +## What is Batch Invariance? -**NKI achieves batch invariance by default due to hardware tile constraints.** +**Batch invariance** means that computing the same element in different batch sizes produces **identical numerical results**. The paper demonstrates that CUDA/PyTorch matrix multiplication is **NOT batch-invariant** due to dynamic optimization strategies that change based on batch size. -Unlike CUDA/PyTorch, where batch size can influence the K-dimension reduction strategy (e.g., switching to split-K for better parallelism when M is small), NKI's hardware constraints enforce fixed tile sizes that decouple batch size from reduction order. +## When Does Batch Variance Occur? -### Key Protection Mechanisms +Batch variance occurs when **ALL THREE conditions are met**: -1. **K is the reduction axis, not the batch axis (M)** - - Reduction happens over K (contraction dimension) - - M (batch) loop is outer, K loop is inner - - Changing M doesn't affect K iteration count +1. **Tiling the reduction dimension** (not parallelizable dimensions) + - MatMul: Tiling K (contraction dimension) ✓ + - RMSNorm: Tiling hidden dimension in split reduction ✓ + - RMSNorm: Tiling batch dimension ✗ (batch is parallelizable) -2. **Hardware constraints enforce fixed tile sizes** - - Tensor Engine limits: P-dim ≤ 128, free-dim ≤ 512 - - Forces compile-time constants (e.g., K_TILE=128) - - Prevents runtime adaptation based on batch size +2. **Iterative accumulation across tiles** (not atomic reductions) + - `c_psum += matmul(a_tile, b_tile)` ✓ Creates variance + - `nl.sum(entire_row)` ✗ Atomic, no variance -3. **Potential vulnerability: Split-K** - - NKI *could* split along K when M is small (like CUDA does) - - This would couple M and K reduction strategy - - Our tests verify this doesn't happen automatically +3. **Dynamic tile size based on input characteristics** + - CUDA: Adapts K strategy based on batch size ✓ + - NKI (fixed): `K_TILE = 128` always ✗ + - NKI (variant): `K_TILE = 64 if K <= 512 else 128` ✓ -## Test Design +## Test Environment -Replicated [Thinking Machines' batch invariance test](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/): +- **Instance**: `inf2.xlarge` (AWS Trainium) +- **AMI ID**: `ami-0ec4ab14b1c5a10f2` +- **AMI Name**: `Deep Learning AMI Neuron (Ubuntu 22.04) 20250919` +- **Compiler**: `neuronxcc-2.21.18209.0` +- **Framework**: NKI (Neuron Kernel Interface) -Instance_type: `inf2.xlarge` -AMI ID: `ami-0ec4ab14b1c5a10f2` -AMI NAME: `Deep Learning AMI Neuron (Ubuntu 22.04) 20250919` -```python -# CUDA shows non-determinism: -out1 = torch.mm(a[:1], b) # M=1 -out2 = torch.mm(a, b)[:1] # M=2048 -# Result: out1 ≠ out2 (diff: 1669.25) - -# NKI test: -out1 = matmul_nki(a[:128], b)[0] # M=128 -out2 = matmul_nki(a, b)[0] # M=2048 -# Result: out1 == out2 (diff: 0.0) ✓ -``` +## Test Suite Overview + +We test three kernel implementations: + +1. **MatMul with K_TILE variation** - Demonstrates reduction dimension tiling variance +2. **RMSNorm (standard)** - Demonstrates natural batch invariance with atomic reductions +3. **RMSNorm (split reduction)** - Demonstrates hidden dimension tiling variance + +Each test compares: +- **Invariant mode**: Fixed tile size (batch-invariant) +- **Variant mode**: Adaptive tile size (batch-variant) +- **Precision impact**: bfloat16 vs float32 ## Results -### Test 1: M_TILE Variation (64 vs 128) -``` -M_TILE=64 → Result: [9664., 9600., ...] -M_TILE=128 → Result: [9664., 9600., ...] -Max difference: 0.0 ✓ INVARIANT -``` -**Conclusion:** Batch tiling strategy doesn't affect results. +### Test 1: MatMul - K_TILE Variance -### Test 2: M (Batch Size) Variation (128 vs 2048) -``` -M=128 → Result: [9664., 9600., ...] -M=2048 → Result: [9664., 9600., ...] -Max difference: 0.0 ✓ INVARIANT -``` -**Conclusion:** True batch invariance achieved. Same element produces identical results regardless of batch size. +**Configuration**: M=128, K=512, N=512 -### Test 3: K_TILE Variation (64 vs 128) - Simulated Dynamic Tiling ``` -K_TILE=128 → Result: [9664., 9600., ...] (32 iterations) -K_TILE=64 → Result: [9664., 9600., ...] (64 iterations) -Max difference: 256.0 ✓ VARIANT (expected) +bfloat16: + K_TILE=128 (invariant): 4 accumulations over K dimension + K_TILE=64 (variant): 8 accumulations over K dimension + Max difference: 0.007812 + Result: DIFFER ✓ + +float32: + K_TILE=128 (invariant): 4 accumulations + K_TILE=64 (variant): 8 accumulations + Max difference: 0.000050 + Result: DIFFER ✓ + +Precision impact: bfloat16 error is 157x larger than float32 ``` -**Conclusion:** Reduction order matters. Different K_TILE → different accumulation order → different floating-point results. This simulates what CUDA does when it adapts K strategy based on batch size. -### Test 4: Loop Iterator (affine_range vs sequential_range) -``` -affine_range → Result: [9664., 9600., ...] -sequential_range → Result: [9664., 9600., ...] -Max difference: 0.0 ✓ INVARIANT -``` -**Conclusion:** Loop iterator type is a compiler hint; doesn't affect numerical output. +**Key Finding**: Different K_TILE sizes create different accumulation orders in the reduction: +- K_TILE=128: `((chunk0 + chunk1) + chunk2) + chunk3` (4 tiles) +- K_TILE=64: `(((((((ch0 + ch1) + ch2) + ch3) + ch4) + ch5) + ch6) + ch7)` (8 tiles) + +Due to floating-point associativity: `(a + b) + c ≠ a + (b + c)` + +### Test 2: RMSNorm (Standard) - Natural Batch Invariance + +**Configuration**: batch_size varies, hidden_dim=256 -### Test 5: Precision Impact (bfloat16 vs float32) ``` -bfloat16 K_TILE diff: 256.0 (2.67% relative error) -float32 K_TILE diff: 15.125 (0.091% relative error) -Amplification: 16.9x +Same 32 rows computed in: + - batch=32 context + - batch=128 context + +Result: MATCH ✓ (identical) +Max difference: 0.0 ``` -**Conclusion:** Lower precision amplifies accumulation order effects. bfloat16's 7-bit mantissa shows 17x larger differences than float32's 23-bit mantissa. -### Test 6: Consistency Check +**Key Finding**: RMSNorm is naturally batch-invariant because: +1. Each row computed independently (no inter-row dependencies) +2. Reduction is atomic: `nl.sum(in_square, axis=[1])` reduces entire hidden dimension at once +3. Batch tiling only affects parallelism, not computation order + +### Test 3: RMSNorm (Split Reduction) - Hidden Dimension Tiling Variance + +**Configuration**: batch_size=64, hidden_dim=512 + ``` -Run 1: 256.0 -Run 2: 256.0 -Run 3: 256.0 -✓ FULLY DETERMINISTIC +bfloat16: + HIDDEN_TILE=256 (invariant): 2 chunks, 1 accumulation + HIDDEN_TILE=128 (variant): 4 chunks, 3 accumulations + Max difference: 0.007812 + Result: DIFFER ✓ + +float32: + HIDDEN_TILE=256 (invariant): 2 chunks, 1 accumulation + HIDDEN_TILE=128 (variant): 4 chunks, 3 accumulations + Max difference: 0.000000 + Result: IDENTICAL + +Precision impact: Variance only visible in bfloat16 ``` -**Conclusion:** The K_TILE difference is consistent and repeatable, not random. + +**Key Finding**: Split reduction creates variance by tiling the **reduction dimension** (hidden_dim): +- Standard RMSNorm: `nl.sum(row)` - atomic, invariant +- Split RMSNorm: `sum(chunk0) + sum(chunk1) + sum(chunk2) + sum(chunk3)` - iterative, variant + +**Important**: Float32 precision is sufficient to make simple addition accumulation errors negligible, unlike multiply-accumulate in MatMul. ## Key Findings -### ✅ Hypothesis Confirmed +### 🎯 Core Principle: Reduction Dimension Tiling Creates Variance + +**Operations are naturally batch-invariant UNTIL:** -**NKI IS BATCH INVARIANT** -- M_TILE doesn't affect results (batch tiling invariant) -- M (batch size) doesn't affect results (true batch invariance) -- K_TILE DOES affect results (reduction order matters) -- But K_TILE is a compile-time constant → fully deterministic +1. ✅ You tile the **reduction dimension** (not parallelizable dimensions) +2. ✅ Tile size changes **dynamically** based on input characteristics +3. ✅ Operation uses **iterative accumulation** (not atomic reductions) -### 📊 Comparison: NKI vs CUDA +**Examples:** +- ❌ **No variance**: RMSNorm batch tiling - tiles parallelizable dimension (batch) +- ✅ **Creates variance**: MatMul K tiling - tiles reduction dimension with accumulation +- ✅ **Creates variance**: RMSNorm split reduction - tiles hidden dimension with accumulation -| Aspect | CUDA | NKI | -|--------|------|-----| -| Batch size affects K reduction? | ✗ Yes (split-K adaptation) | ✅ No (fixed K_TILE) | -| Run-to-run deterministic? | ✗ No (varies ~1669) | ✅ Yes (always identical) | -| K_TILE matters? | ✅ Yes | ✅ Yes | -| Tile size constraints? | Flexible | Hardware-enforced (≤128/512) | +### 📊 Precision Amplifies Variance -### 🔬 Why NKI Wins +| Operation | bfloat16 Error | float32 Error | Amplification | +|-----------|---------------|---------------|---------------| +| MatMul (K_TILE) | 0.007812 | 0.000050 | **157x** | +| RMSNorm Split (HIDDEN_TILE) | 0.007812 | ~0.000000 | Only visible in bfloat16 | -1. **M/K decoupling:** Batch loop (M) is outer, reduction loop (K) is inner. Changing batch size doesn't affect K iteration count. +**Critical Insight**: Reduced precision (bfloat16) amplifies tiling variance dramatically: +- **Multiply-accumulate** (MatMul): Errors compound quickly, visible in both precisions +- **Pure addition** (RMSNorm sum): Errors compound slowly, only visible in bfloat16 +- **Implication**: bfloat16 users need batch-invariant implementations more urgently -2. **Hardware constraints as a feature:** Tensor Engine limits force compile-time K_TILE constants, preventing runtime adaptation. +### 🔬 Replicating Paper Findings with NKI -3. **No automatic split-K:** NKI doesn't dynamically switch to split-K based on batch size. You'd need to write a separate kernel. +Our results directly replicate [Thinking Machines' findings](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/): -## Implications +**Paper's observation (CUDA):** +> "CUDA adapts K reduction strategy based on batch size, causing non-determinism" -**For LLM Inference:** -- Batch-invariant by default (no special kernels needed like Thinking Machines built for CUDA) -- Deterministic sampling at temperature=0 (if K_TILE is fixed) -- True on-policy RL possible (identical numerics between training and inference) +**Our NKI implementation:** +```python +# Batch-variant: Mimics CUDA's dynamic strategy +K_TILE = 64 if K <= 512 else 128 + +# Batch-invariant: Fixed strategy (paper's solution) +K_TILE = 128 # Always +``` -**Caveats:** -- K_TILE variation causes 2.67% relative error in bfloat16 (acceptable for most LLM use cases) -- Must use consistent K_TILE across kernels for bitwise reproducibility -- Lower precision (bfloat16) amplifies accumulation order effects 17x vs float32 +**Result**: Same variance pattern observed in NKI when we explicitly code dynamic tiling, confirming the paper's root cause analysis. + +### 🛡️ NKI's Natural Protection + +**Why NKI tends toward batch-invariance:** + +1. **Hardware constraints enforce constants** + - Tensor Engine limits: P-dim ≤ 128, free-dim ≤ 512 + - Encourages fixed compile-time tile sizes + - Makes dynamic adaptation less natural + +2. **Explicit control over tiling** + - Developers explicitly set K_TILE, HIDDEN_TILE, etc. + - No "magic" runtime optimization that varies strategy + - Batch-invariance is default unless explicitly coded otherwise + +3. **Atomic operations where possible** + - `nl.sum(entire_dimension)` is atomic - naturally invariant + - Only manual tiling creates variance + +## Implications for LLM Inference + +### ✅ Benefits + +1. **Deterministic inference** - Same outputs for temperature=0 sampling regardless of batch size +2. **On-policy RL** - Training and inference produce identical numerics +3. **Debugging** - Reproducible results across batch sizes simplifies debugging +4. **Cache coherence** - KV-cache values identical whether computed individually or batched + +### ⚠️ Requirements for Batch-Invariance + +1. **Fix reduction tile sizes** + ```python + # ❌ BAD: Dynamic tiling + K_TILE = 64 if K <= 512 else 128 + + # ✅ GOOD: Fixed tiling + K_TILE = 128 # Always + ``` + +2. **Use consistent precision** + - bfloat16 shows 157x larger variance than float32 + - Mixed precision can break invariance + +3. **Avoid split reductions when possible** + - Prefer atomic reductions: `nl.sum(entire_dimension)` + - If split necessary, use fixed tile sizes ## Conclusion -NKI's tile size constraints, enforced by hardware limitations, provide batch invariance as an inherent property rather than requiring specialized implementations. The decoupling of batch size (M) from reduction strategy (K_TILE) ensures that the same element produces identical results regardless of the batch it's computed in. +NKI naturally encourages batch-invariant implementations through: +- Hardware-enforced tile size constraints +- Explicit tiling control (no magic runtime optimization) +- Atomic reduction operations as primitives + +However, variance can still occur when: +- Manually implementing split reductions with dynamic tile sizes +- Using reduced precision (bfloat16) with iterative accumulation +- Adapting strategies based on input characteristics + +**Our findings directly replicate the Thinking Machines paper**: Batch variance stems from **dynamic tiling of reduction dimensions**, and the solution is **fixed tiling strategies**. NKI makes this easier by design, but developers must still be intentional about tile size choices, especially when using bfloat16 precision. + +## Running the Tests + +```bash +cd contributed/batch_invariance +python test_batch_invariance.py +``` + +**Expected Output:** +``` +================================================================================ +Testing MatMul batch invariance... + Testing with bfloat16: + Max difference between K_TILE strategies: 0.007812 + Results differ + Testing with float32: + Max difference between K_TILE strategies: 0.000050 + Results differ + Precision impact: bfloat16 error is 157x larger than float32 + +================================================================================ +Testing RMSNorm batch invariance... + First 32 rows: batch=32 vs batch=128: MATCH ✓ + ✓ RMSNorm is batch-invariant! + +================================================================================ +Testing RMSNorm with Split Reduction... + Testing with bfloat16: + Max difference between HIDDEN_TILE strategies: 0.007812 + Results differ + Testing with float32: + Max difference between HIDDEN_TILE strategies: 0.000000 + Results identical +``` + +## Files + +- `kernels/matmul_batch_invariant.py` - MatMul with configurable K_TILE +- `kernels/rmsnorm_batch_invariant.py` - Standard RMSNorm (atomic reduction) +- `kernels/rmsnorm_split_reduction.py` - RMSNorm with split reduction (demonstrates variance) +- `test_batch_invariance.py` - Comprehensive test suite +- `README.md` - This document + +## References -**Bottom line:** CUDA varies K reduction order *unpredictably* based on batch size. NKI keeps it *fixed* based on compile-time K_TILE. That's the win. +- [Thinking Machines: Defeating Nondeterminism in LLM Inference](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) +- [AWS Neuron Documentation](https://awsdocs-neuron.readthedocs-hosted.com/) +- [NKI Programming Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/) diff --git a/contributed/batch_invariance/kernels/__init__.py b/contributed/batch_invariance/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contributed/batch_invariance/kernels/matmul_batch_invariant.py b/contributed/batch_invariance/kernels/matmul_batch_invariant.py new file mode 100644 index 0000000..60f6918 --- /dev/null +++ b/contributed/batch_invariance/kernels/matmul_batch_invariant.py @@ -0,0 +1,57 @@ +""" +Batch-Invariant MatMul Kernel + +This kernel demonstrates batch invariance in matrix multiplication by controlling +the M-dimension tiling strategy. +""" + +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl + + +@nki.jit +def nki_matmul_kernel(a, b, batch_invariant=True): + """ + Matrix multiplication with batch invariance parameter + + batch_invariant=True: Uses K_TILE=128 + batch_invariant=False: Uses K_TILE=64 + + This demonstrates how different K tiling affects numerical results. + """ + M, K = a.shape + N = b.shape[1] + M_TILE = 128 + + # ONLY DIFFERENCE: K_TILE strategy + if batch_invariant: + K_TILE = 128 # Always hardcoded + else: + K_TILE = 64 if K <= 512 else 128 # Adaptive + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + # Use EXACT same logic as working matmul_m128 + for m in nl.affine_range(M // M_TILE): + # Accumulator for this M chunk + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Reduction over K + for k in nl.affine_range(K // K_TILE): + # Load a: [M_TILE, K_TILE] + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + # Load b: [K_TILE, N] + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + # Matmul + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + # Store this M chunk + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result \ No newline at end of file diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py new file mode 100644 index 0000000..1b2dfbc --- /dev/null +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -0,0 +1,82 @@ +""" +Batch-Invariant RMSNorm Kernel +""" + +import math +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl + + +@nki.jit +def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): + """ + RMSNorm with batch invariance parameter + + This demonstrates TRUE batch invariance testing: + - batch_invariant=True: Always uses tile_size=128 (same strategy regardless of batch) + - batch_invariant=False: Adapts tile_size based on batch size (different strategies) + """ + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) + + # Make sure shapes match + assert a_tensor.shape[1] == g_tensor.shape[0] + + num_rows = a_tensor.shape[0] + hidden_dim = a_tensor.shape[1] + + # CRITICAL: Tile size based on BATCH SIZE (not hidden_dim) + # This is what creates batch variance! + if batch_invariant: + # INVARIANT: Fixed strategy regardless of batch size + tile_size = 128 + else: + # VARIANT: Strategy changes based on batch size + # Small batches get smaller tiles -> different processing pattern + if num_rows <= 64: + tile_size = 32 # Small batch: smaller tiles + else: + tile_size = 128 # Large batch: larger tiles + + # Generate tensor indices based on tile_size + ix = nl.arange(tile_size)[:, None] + iw = nl.arange(1)[:, None] + iy = nl.arange(hidden_dim)[None, :] + + # Load RMSNorm weight once + g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy]) + + # Process tile_size rows at a time + for i in nl.affine_range(math.ceil(num_rows / tile_size)): + + # Load input data from external memory to on-chip memory + a_tile = nl.load(a_tensor[i * tile_size + ix, iy], + mask=(i * tile_size + ix < num_rows)) + + # Compute element-wise square of a_tensor + in_square = nl.square(a_tile) + + # Calculate sum of squared elements, along last dimension + square_sum = nl.sum(in_square, axis=[1]) + + # Scale and get a reciprocal + mean = square_sum / hidden_dim + + # Take square root of mean and then reciprocal with rsqrt API + rms_reciprocal = nl.rsqrt(mean) + + # Scale the input tensor + out_tile = nl.multiply(a_tile, rms_reciprocal) + + # Broadcast weight along first axis to match tensor shape + g_bcast = g_tile.broadcast_to((tile_size, hidden_dim)) + + # Multiply with the RMSNorm weight + out_tile[...] = nl.multiply(out_tile, g_bcast, + mask=(i * tile_size + ix < num_rows)) + + # store the results back to external memory + nl.store(out_tensor[i * tile_size + ix, iy], value=out_tile, + mask=(i * tile_size + ix < num_rows)) + + return out_tensor \ No newline at end of file diff --git a/contributed/batch_invariance/kernels/rmsnorm_split_reduction.py b/contributed/batch_invariance/kernels/rmsnorm_split_reduction.py new file mode 100644 index 0000000..524ec5c --- /dev/null +++ b/contributed/batch_invariance/kernels/rmsnorm_split_reduction.py @@ -0,0 +1,104 @@ +""" +RMSNorm with Split Reduction - Demonstrates TRUE Batch Variance + +This kernel tiles the HIDDEN DIMENSION (reduction axis) instead of just the batch dimension. +This creates different accumulation orders and breaks batch-invariance! +""" + +import math +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl + + +@nki.jit +def nki_rmsnorm_split_reduction(a_tensor, g_tensor, batch_invariant=True): + """ + RMSNorm with split reduction along hidden dimension + + batch_invariant=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) + batch_invariant=False: HIDDEN_TILE=128 (more chunks, more accumulations) + + This demonstrates REAL batch variance because different tile sizes + change the order of floating-point additions during reduction. + """ + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) + + assert a_tensor.shape[1] == g_tensor.shape[0] + + num_rows = a_tensor.shape[0] + hidden_dim = a_tensor.shape[1] + BATCH_TILE = 128 + + # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) + # Different sizes = different number of accumulations = variance! + if batch_invariant: + HIDDEN_TILE = 256 # Fewer chunks (e.g., 2 for hidden_dim=512) + else: + HIDDEN_TILE = 128 # More chunks (e.g., 4 for hidden_dim=512) + + ix = nl.arange(BATCH_TILE)[:, None] + iw = nl.arange(1)[:, None] + + # Process batch in tiles + for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): + + # SPLIT REDUCTION: Accumulate partial sums across hidden dimension chunks + # Use PSUM for accumulation (always float32 internally) + partial_square_sum = nl.zeros((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.psum) + + # Iterate over hidden dimension in chunks + num_hidden_tiles = math.ceil(hidden_dim / HIDDEN_TILE) + for h in nl.affine_range(num_hidden_tiles): + h_start = h * HIDDEN_TILE + + # Create indices for this hidden chunk (always use full HIDDEN_TILE, mask later) + iy = nl.arange(HIDDEN_TILE)[None, :] + + # Create mask for valid hidden indices + valid_mask = ((i * BATCH_TILE + ix < num_rows) & + (h * HIDDEN_TILE + iy < hidden_dim)) + + # Load a CHUNK of the hidden dimension with proper indexing + a_chunk = nl.load( + a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], + mask=valid_mask + ) + + # Square this chunk + in_square_chunk = nl.square(a_chunk) + + # Reduce this chunk (sum along hidden dimension) + # Mask ensures we only sum valid elements + chunk_sum = nl.sum(in_square_chunk, axis=[1], keepdims=True, + mask=valid_mask) + + # ACCUMULATE: This is where variance enters! + # Different HIDDEN_TILE sizes mean different number of additions + partial_square_sum += chunk_sum + + # Compute mean and RMS + mean = partial_square_sum / hidden_dim + rms_reciprocal = nl.rsqrt(mean) + + # Now load full row for normalization + iy_full = nl.arange(hidden_dim)[None, :] + a_tile = nl.load( + a_tensor[i * BATCH_TILE + ix, iy_full], + mask=(i * BATCH_TILE + ix < num_rows) + ) + + # Normalize by RMS + out_tile = nl.multiply(a_tile, rms_reciprocal) + + # Apply weight + g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_full]) + g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) + out_tile = nl.multiply(out_tile, g_bcast, + mask=(i * BATCH_TILE + ix < num_rows)) + + # Store result + nl.store(out_tensor[i * BATCH_TILE + ix, iy_full], value=out_tile, + mask=(i * BATCH_TILE + ix < num_rows)) + + return out_tensor diff --git a/contributed/batch_invariance/test_batch_invariance.py b/contributed/batch_invariance/test_batch_invariance.py new file mode 100644 index 0000000..c469da7 --- /dev/null +++ b/contributed/batch_invariance/test_batch_invariance.py @@ -0,0 +1,149 @@ +""" +Simple Batch Invariance Test +""" + +import torch +import torch_neuronx +import numpy as np +from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel +from kernels.rmsnorm_split_reduction import nki_rmsnorm_split_reduction +from kernels.matmul_batch_invariant import nki_matmul_kernel as matmul_batch_invariant + + +def test_matmul(): + """MatMul test showing K_TILE effect and precision impact""" + print("Testing MatMul batch invariance...") + + device = 'xla' + M, K, N = 128, 512, 512 # K=512 triggers different behavior! + + print(f" K={K} -> batch_invariant=True: K_TILE=128, batch_invariant=False: K_TILE=64") + print() + + # Test with bfloat16 + print(" Testing with bfloat16:") + a_bf16 = torch.linspace(-1, 1, M * K, device=device).reshape(M, K).to(torch.bfloat16) + b_bf16 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.bfloat16) + + result_inv_bf16 = matmul_batch_invariant(a_bf16, b_bf16, batch_invariant=True) # K_TILE=128 + result_var_bf16 = matmul_batch_invariant(a_bf16, b_bf16, batch_invariant=False) # K_TILE=64 + + diff_bf16 = torch.max(torch.abs(result_inv_bf16 - result_var_bf16)).item() + print(f" Max difference between K_TILE strategies: {diff_bf16:.6f}") + print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + + print() + + # Test with float32 + print(" Testing with float32:") + a_f32 = torch.linspace(-1, 1, M * K, device=device).reshape(M, K).to(torch.float32) + b_f32 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.float32) + + result_inv_f32 = matmul_batch_invariant(a_f32, b_f32, batch_invariant=True) # K_TILE=128 + result_var_f32 = matmul_batch_invariant(a_f32, b_f32, batch_invariant=False) # K_TILE=64 + + diff_f32 = torch.max(torch.abs(result_inv_f32 - result_var_f32)).item() + print(f" Max difference between K_TILE strategies: {diff_f32:.6f}") + print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") + + print() + print(f" Precision impact: bfloat16 error is {diff_bf16/diff_f32 if diff_f32 > 0 else 'N/A'}x larger than float32") + print(f" This demonstrates how reduced precision amplifies tiling strategy effects") + + +def test_rmsnorm(): + """RMSNorm demonstrates batch INVARIANCE (not variance)""" + print("Testing RMSNorm batch invariance...") + + device = 'xla' + hidden_dim = 256 + + # Create a large input with many rows + large_batch = 128 + a_large = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.bfloat16) + g = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16) + + # Test the SAME 32 rows in different batch contexts + a_small = a_large[:32, :] + + # Process as small batch (32 rows) + result_small = nki_rmsnorm_kernel(a_small, g, batch_invariant=True) + + # Process as part of large batch (128 rows) + result_large = nki_rmsnorm_kernel(a_large, g, batch_invariant=True) + + # Compare the SAME rows + match = torch.allclose(result_small, result_large[:32], atol=1e-6) + print(f" First 32 rows: batch=32 vs batch=128: {'MATCH ✓' if match else 'DIFFER ✗'}") + + if match: + print(f" ✓ RMSNorm is batch-invariant!") + print(f" Each row computed independently, reduction is atomic") + print(f" Tile size only affects parallelism, not computation order") + + +def test_rmsnorm_split_reduction(): + """RMSNorm with SPLIT REDUCTION demonstrates TRUE batch VARIANCE""" + print("Testing RMSNorm with Split Reduction...") + print(" (Tiling the HIDDEN dimension creates different accumulation orders)") + + device = 'xla' + hidden_dim = 512 # Use 512 to see clear difference + batch_size = 64 + + print(f" hidden_dim={hidden_dim}") + print(f" batch_invariant=True: HIDDEN_TILE=256 (2 chunks, 1 accumulation)") + print(f" batch_invariant=False: HIDDEN_TILE=128 (4 chunks, 3 accumulations)") + print() + + # Test with bfloat16 + print(" Testing with bfloat16:") + a_bf16 = torch.linspace(-1, 1, batch_size * hidden_dim, device=device).reshape(batch_size, hidden_dim).to(torch.bfloat16) + g_bf16 = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16) + + result_inv_bf16 = nki_rmsnorm_split_reduction(a_bf16, g_bf16, batch_invariant=True) # HIDDEN_TILE=256 + result_var_bf16 = nki_rmsnorm_split_reduction(a_bf16, g_bf16, batch_invariant=False) # HIDDEN_TILE=128 + + diff_bf16 = torch.max(torch.abs(result_inv_bf16 - result_var_bf16)).item() + print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") + print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + + print() + + # Test with float32 + print(" Testing with float32:") + a_f32 = torch.linspace(-1, 1, batch_size * hidden_dim, device=device).reshape(batch_size, hidden_dim).to(torch.float32) + g_f32 = torch.ones(hidden_dim, device=device, dtype=torch.float32) + + result_inv_f32 = nki_rmsnorm_split_reduction(a_f32, g_f32, batch_invariant=True) # HIDDEN_TILE=256 + result_var_f32 = nki_rmsnorm_split_reduction(a_f32, g_f32, batch_invariant=False) # HIDDEN_TILE=128 + + diff_f32 = torch.max(torch.abs(result_inv_f32 - result_var_f32)).item() + print(f" Max difference between HIDDEN_TILE strategies: {diff_f32:.6f}") + print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") + + print() + print(f" Precision impact: bfloat16 error is {diff_bf16/diff_f32 if diff_f32 > 0 else 'N/A'}x larger than float32") + print(f" ✓ Split reduction creates batch variance in BOTH precisions!") + print(f" Different hidden tile sizes → different accumulation order") + print(f" This is analogous to MatMul's K_TILE effect") + + +if __name__ == "__main__": + print("Batch Invariance Test") + print("=" * 80) + + test_matmul() + print() + print("=" * 80) + test_rmsnorm() + print() + print("=" * 80) + test_rmsnorm_split_reduction() + + print("\n" + "=" * 80) + print("SUMMARY:") + print(" • MatMul: K_TILE variance - different reduction chunking") + print(" • RMSNorm (standard): Batch-invariant - atomic reduction") + print(" • RMSNorm (split): HIDDEN_TILE variance - reduction chunking") + print("\nDone!") \ No newline at end of file diff --git a/contributed/batch_invariance/test_batch_invariance_nki.py b/contributed/batch_invariance/test_batch_invariance_nki.py deleted file mode 100644 index 206f143..0000000 --- a/contributed/batch_invariance/test_batch_invariance_nki.py +++ /dev/null @@ -1,423 +0,0 @@ -""" -Minimal NKI Batch Invariance Test - Clean Implementation - -Tests if dynamic M tiling introduces non-determinism in matmul. -Based on NKI matmul example pattern. -""" - -import torch -import torch_neuronx -import neuronxcc.nki as nki -import neuronxcc.nki.language as nl - - -@nki.jit -def matmul_m64(a, b): - """ - Matmul with M tiled at 64 - a: [M, 4096], b: [4096, 512] - Output: [M, 512] - - Works with any M that's divisible by 64 - """ - M, K = a.shape - N = b.shape[1] - M_TILE = 64 - K_TILE = 128 - - result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) - - # Tile over M dimension - for m in nl.affine_range(M // M_TILE): - # Accumulator for this M chunk - c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) - - # Reduction over K - for k in nl.affine_range(K // K_TILE): - # Load a: [M_TILE, K_TILE] - i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] - a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) - - # Load b: [K_TILE, N] - i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] - b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) - - # Matmul - c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) - - # Store this M chunk - i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] - c_sbuf = nl.copy(c_psum, dtype=result.dtype) - nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) - - return result - - -@nki.jit -def matmul_m128(a, b): - """ - Matmul with M tiled at 128 - a: [M, 4096], b: [4096, 512] - Output: [M, 512] - - Works with any M that's divisible by 128 - """ - M, K = a.shape - N = b.shape[1] - M_TILE = 128 - K_TILE = 128 - - result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) - - # Tile over M dimension - for m in nl.affine_range(M // M_TILE): - # Accumulator for this M chunk - c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) - - # Reduction over K - for k in nl.affine_range(K // K_TILE): - # Load a: [M_TILE, K_TILE] - i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] - a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) - - # Load b: [K_TILE, N] - i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] - b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) - - # Matmul - c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) - - # Store this M chunk - i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] - c_sbuf = nl.copy(c_psum, dtype=result.dtype) - nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) - - return result - - -@nki.jit -def matmul_k64(a, b): - """ - Matmul with K tiled at 64 (different contraction tile size) - - This should produce DIFFERENT results than K_TILE=128 - because the reduction order changes! - """ - M, K = a.shape - N = b.shape[1] - M_TILE = 128 - K_TILE = 64 # DIFFERENT K tiling! - - result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) - - for m in nl.affine_range(M // M_TILE): - c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) - - # Now we have TWICE as many K iterations (64 instead of 32) - for k in nl.affine_range(K // K_TILE): - i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] - a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) - - i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] - b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) - - c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) - - i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] - c_sbuf = nl.copy(c_psum, dtype=result.dtype) - nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) - - return result - - -@nki.jit -def matmul_sequential(a, b): - """ - Matmul using sequential_range instead of affine_range - - sequential_range forces sequential execution with loop-carried dependency. - Question: Does this affect determinism? - """ - M, K = a.shape - N = b.shape[1] - M_TILE = 128 - K_TILE = 128 - - result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) - - for m in nl.affine_range(M // M_TILE): - c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) - - # Using sequential_range - tells compiler there's loop dependency - for k in nl.sequential_range(K // K_TILE): - i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] - a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) - - i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] - b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) - - c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) - - i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] - c_sbuf = nl.copy(c_psum, dtype=result.dtype) - nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) - - return result - - -@nki.jit -def matmul_m128_fp32(a, b): - """ - Matmul with M_TILE=128, but using float32 inputs - To compare precision differences vs bfloat16 - """ - M, K = a.shape - N = b.shape[1] - M_TILE = 128 - K_TILE = 128 - - result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) - - for m in nl.affine_range(M // M_TILE): - c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) - - for k in nl.affine_range(K // K_TILE): - i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] - a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) - - i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] - b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) - - c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) - - i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] - c_sbuf = nl.copy(c_psum, dtype=result.dtype) - nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) - - return result - - -@nki.jit -def matmul_k64_fp32(a, b): - """ - Matmul with K_TILE=64, using float32 inputs - """ - M, K = a.shape - N = b.shape[1] - M_TILE = 128 - K_TILE = 64 - - result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) - - for m in nl.affine_range(M // M_TILE): - c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) - - for k in nl.affine_range(K // K_TILE): - i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] - a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) - - i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] - b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) - - c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) - - i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] - c_sbuf = nl.copy(c_psum, dtype=result.dtype) - nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) - - return result - - -def test_batch_invariance(): - """ - Comprehensive batch invariance testing suite - """ - B, D, N = 2048, 4096, 512 - - # Create test inputs on XLA device - device = 'xla' - a = torch.linspace(-100, 100, B*D, device=device).reshape(B, D).to(torch.bfloat16) - b = torch.linspace(-100, 100, D*N, device=device).reshape(D, N).to(torch.bfloat16) - - print("=" * 70) - print("TEST 1: Different M_TILE on same input") - print("=" * 70) - print(f"Input: [{B}, {D}] @ [{D}, {N}]") - print(f"M_TILE=64: {B//64} iterations over M, K_TILE=128") - print(f"M_TILE=128: {B//128} iterations over M, K_TILE=128") - print() - - c_m64 = matmul_m64(a, b) - c_m128 = matmul_m128(a, b) - - c_m64_cpu = c_m64.cpu() - c_m128_cpu = c_m128.cpu() - - print("Results:") - print(f" M_TILE=64 row[0]: {c_m64_cpu[0, :5]}") - print(f" M_TILE=128 row[0]: {c_m128_cpu[0, :5]}") - - diff1 = (c_m64_cpu - c_m128_cpu).abs().max() - print(f"\n Max difference: {diff1.item()}") - print(f" Bitwise identical: {diff1.item() == 0}") - - print("\n" + "=" * 70) - print("TEST 2: Thinking Machines scenario - varying M (batch size)") - print("=" * 70) - print("The real batch invariance test!") - print(f"Compute row 0 with M=128 vs M=2048") - print() - - a_small = a[:128, :] - c_small = matmul_m128(a_small, b) - c_full = matmul_m128(a, b) - - c_small_cpu = c_small.cpu() - c_full_cpu = c_full.cpu() - - print("Results:") - print(f" M=128 row[0]: {c_small_cpu[0, :5]}") - print(f" M=2048 row[0]: {c_full_cpu[0, :5]}") - - diff2 = (c_small_cpu[0] - c_full_cpu[0]).abs().max() - print(f"\n Max difference: {diff2.item()}") - print(f" Bitwise identical: {diff2.item() == 0}") - - print("\n" + "=" * 70) - print("TEST 3: Different K_TILE - Does reduction order matter?") - print("=" * 70) - print("K_TILE=128: 32 K iterations (accumulate chunks: 0, 128, 256, ...)") - print("K_TILE=64: 64 K iterations (accumulate chunks: 0, 64, 128, ...)") - print("Different accumulation order → different floating point results!") - print() - - c_k128 = matmul_m128(a, b) # K_TILE=128 - c_k64 = matmul_k64(a, b) # K_TILE=64 - - c_k128_cpu = c_k128.cpu() - c_k64_cpu = c_k64.cpu() - - print("Results:") - print(f" K_TILE=128 row[0]: {c_k128_cpu[0, :5]}") - print(f" K_TILE=64 row[0]: {c_k64_cpu[0, :5]}") - - diff3 = (c_k128_cpu - c_k64_cpu).abs().max() - print(f"\n Max difference: {diff3.item()}") - print(f" Are they different? {diff3.item() != 0}") - - if diff3.item() != 0: - print(" ✓ EXPECTED! Different K_TILE → different reduction order") - else: - print(" ✗ UNEXPECTED! K_TILE should matter for floating point") - - print("\n" + "=" * 70) - print("TEST 4: sequential_range vs affine_range") - print("=" * 70) - print("affine_range: parallel-friendly, allows loop optimizations") - print("sequential_range: forces sequential execution, loop dependency") - print("Question: Do they produce identical results?") - print() - - c_affine = matmul_m128(a, b) # Uses affine_range - c_sequential = matmul_sequential(a, b) # Uses sequential_range - - c_affine_cpu = c_affine.cpu() - c_sequential_cpu = c_sequential.cpu() - - print("Results:") - print(f" affine_range row[0]: {c_affine_cpu[0, :5]}") - print(f" sequential_range row[0]: {c_sequential_cpu[0, :5]}") - - diff4 = (c_affine_cpu - c_sequential_cpu).abs().max() - print(f"\n Max difference: {diff4.item()}") - print(f" Bitwise identical: {diff4.item() == 0}") - - if diff4.item() == 0: - print(" ✓ Loop iterator type doesn't affect determinism!") - else: - print(" ✗ sequential_range changes results!") - - print("\n" + "=" * 70) - print("TEST 5: Precision Test - bfloat16 vs float32") - print("=" * 70) - print("Does reduced precision (bfloat16) amplify K_TILE differences?") - print("bfloat16: 7 bits mantissa, ~2-3 decimal digits precision") - print("float32: 23 bits mantissa, ~7 decimal digits precision") - print() - - # Create float32 inputs - a_fp32 = torch.linspace(-100, 100, B*D, device=device).reshape(B, D).to(torch.float32) - b_fp32 = torch.linspace(-100, 100, D*N, device=device).reshape(D, N).to(torch.float32) - - # Run with different K_TILE on float32 - c_k128_fp32 = matmul_m128_fp32(a_fp32, b_fp32) - c_k64_fp32 = matmul_k64_fp32(a_fp32, b_fp32) - - c_k128_fp32_cpu = c_k128_fp32.cpu() - c_k64_fp32_cpu = c_k64_fp32.cpu() - - print("Results (float32):") - print(f" K_TILE=128 row[0]: {c_k128_fp32_cpu[0, :5]}") - print(f" K_TILE=64 row[0]: {c_k64_fp32_cpu[0, :5]}") - - diff5_fp32 = (c_k128_fp32_cpu - c_k64_fp32_cpu).abs().max() - print(f"\n Max difference (float32): {diff5_fp32.item()}") - - print("\nComparison:") - print(f" bfloat16 K_TILE diff: {diff3.item()}") - print(f" float32 K_TILE diff: {diff5_fp32.item()}") - print(f" Ratio (bf16/fp32): {diff3.item() / diff5_fp32.item():.2f}x") - - if diff5_fp32.item() < diff3.item(): - print(f"\n ✓ float32 reduces error by {diff3.item() / diff5_fp32.item():.1f}x!") - print(" Lower precision (bfloat16) amplifies accumulation order effects") - else: - print("\n ✗ Unexpected: float32 doesn't reduce error significantly") - - # Also check: Is the difference consistent across runs? - print("\n" + "=" * 70) - print("TEST 6: Consistency Check - Is K_TILE difference stable?") - print("=" * 70) - print("Running K_TILE test 3 times to verify determinism...") - print() - - diffs = [] - for run in range(3): - c_k128_run = matmul_m128(a, b) - c_k64_run = matmul_k64(a, b) - diff_run = (c_k128_run.cpu() - c_k64_run.cpu()).abs().max().item() - diffs.append(diff_run) - print(f" Run {run+1}: max diff = {diff_run}") - - if len(set(diffs)) == 1: - print(f"\n ✓ FULLY DETERMINISTIC! All runs: {diffs[0]}") - print(" The 256.0 difference is consistent and repeatable") - else: - print(f"\n ✗ Non-deterministic! Diffs vary: {diffs}") - - print("\n" + "=" * 70) - print("FINAL VERDICT") - print("=" * 70) - - print(f"\n1. M_TILE variation (64 vs 128): {'✓ INVARIANT' if diff1.item() == 0 else '✗ VARIANT'}") - print(f"2. M variation (128 vs 2048): {'✓ INVARIANT' if diff2.item() == 0 else '✗ VARIANT'}") - print(f"3. K_TILE variation (64 vs 128): {'✓ VARIANT (expected)' if diff3.item() != 0 else '✗ INVARIANT (unexpected)'}") - print(f"4. Loop iterator (affine vs seq): {'✓ INVARIANT' if diff4.item() == 0 else '✗ VARIANT'}") - print(f"5. Precision (bf16 vs fp32): {diff3.item():.1f} vs {diff5_fp32.item():.4f} ({diff3.item()/diff5_fp32.item():.1f}x)") - print(f"6. Consistency across runs: {'✓ DETERMINISTIC' if len(set(diffs)) == 1 else '✗ NON-DETERMINISTIC'}") - - if diff1.item() == 0 and diff2.item() == 0: - print("\n" + "🎉 " * 20) - print("NKI IS BATCH INVARIANT!") - print(" • M_TILE doesn't affect results (batch tiling invariant)") - print(" • M (batch size) doesn't affect results (true batch invariance)") - print(" • K_TILE DOES affect results (reduction order matters)") - print(f" • bfloat16 amplifies differences by {diff3.item()/diff5_fp32.item():.1f}x vs float32") - print(" • But for FIXED K_TILE, results are fully deterministic!") - print("🎉 " * 20) - else: - print("\n✗ Batch invariance NOT achieved") - - -if __name__ == "__main__": - test_batch_invariance() \ No newline at end of file From a5f821d74efaedacd2276c9c0f182a061f5e8806 Mon Sep 17 00:00:00 2001 From: Jlonge4 Date: Wed, 15 Oct 2025 10:22:59 -0400 Subject: [PATCH 03/21] replicate rmsnorm --- .../batch_invariance/kernels/matmul_batch_invariant.py | 1 - .../batch_invariance/kernels/rmsnorm_batch_invariant.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/contributed/batch_invariance/kernels/matmul_batch_invariant.py b/contributed/batch_invariance/kernels/matmul_batch_invariant.py index 60f6918..7e52b09 100644 --- a/contributed/batch_invariance/kernels/matmul_batch_invariant.py +++ b/contributed/batch_invariance/kernels/matmul_batch_invariant.py @@ -31,7 +31,6 @@ def nki_matmul_kernel(a, b, batch_invariant=True): result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) - # Use EXACT same logic as working matmul_m128 for m in nl.affine_range(M // M_TILE): # Accumulator for this M chunk c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py index 1b2dfbc..4917eae 100644 --- a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -12,9 +12,10 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): """ RMSNorm with batch invariance parameter - This demonstrates TRUE batch invariance testing: + This demonstrates batch invariance testing: - batch_invariant=True: Always uses tile_size=128 (same strategy regardless of batch) - batch_invariant=False: Adapts tile_size based on batch size (different strategies) + - This shows that varying the tiling strategy based on batch size does NOT affect results as we are not reducing across the batch dimension """ out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm) @@ -25,13 +26,11 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): num_rows = a_tensor.shape[0] hidden_dim = a_tensor.shape[1] - # CRITICAL: Tile size based on BATCH SIZE (not hidden_dim) - # This is what creates batch variance! if batch_invariant: # INVARIANT: Fixed strategy regardless of batch size tile_size = 128 else: - # VARIANT: Strategy changes based on batch size + # Also INVARIANT: Strategy changes based on batch size # Small batches get smaller tiles -> different processing pattern if num_rows <= 64: tile_size = 32 # Small batch: smaller tiles From 16cd709c3a8e06829a699813c21ab097f26b35a3 Mon Sep 17 00:00:00 2001 From: Jlonge4 Date: Wed, 15 Oct 2025 10:29:46 -0400 Subject: [PATCH 04/21] replicate rmsnorm --- contributed/batch_invariance/README.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md index 26e6b36..d9a5dd7 100644 --- a/contributed/batch_invariance/README.md +++ b/contributed/batch_invariance/README.md @@ -13,7 +13,6 @@ Batch variance occurs when **ALL THREE conditions are met**: 1. **Tiling the reduction dimension** (not parallelizable dimensions) - MatMul: Tiling K (contraction dimension) ✓ - RMSNorm: Tiling hidden dimension in split reduction ✓ - - RMSNorm: Tiling batch dimension ✗ (batch is parallelizable) 2. **Iterative accumulation across tiles** (not atomic reductions) - `c_psum += matmul(a_tile, b_tile)` ✓ Creates variance @@ -86,10 +85,10 @@ Result: MATCH ✓ (identical) Max difference: 0.0 ``` -**Key Finding**: RMSNorm is naturally batch-invariant because: -1. Each row computed independently (no inter-row dependencies) -2. Reduction is atomic: `nl.sum(in_square, axis=[1])` reduces entire hidden dimension at once -3. Batch tiling only affects parallelism, not computation order +**RMSNorm remains batch-invariant UNTIL you:** +- Tile the **hidden dimension** (the reduction axis) instead of the batch dimension +- Make that tile size **dynamic** based on input characteristics +- Use **iterative accumulation** across hidden dimension chunks (see Test 3 for this scenario) ### Test 3: RMSNorm (Split Reduction) - Hidden Dimension Tiling Variance @@ -108,14 +107,14 @@ float32: Max difference: 0.000000 Result: IDENTICAL -Precision impact: Variance only visible in bfloat16 +Precision impact: Variance only visible in bfloat16 for this test ``` **Key Finding**: Split reduction creates variance by tiling the **reduction dimension** (hidden_dim): - Standard RMSNorm: `nl.sum(row)` - atomic, invariant - Split RMSNorm: `sum(chunk0) + sum(chunk1) + sum(chunk2) + sum(chunk3)` - iterative, variant -**Important**: Float32 precision is sufficient to make simple addition accumulation errors negligible, unlike multiply-accumulate in MatMul. +**Important**: Float32 precision may be sufficient to make simple addition accumulation errors negligible, unlike multiply-accumulate in MatMul. ## Key Findings @@ -142,7 +141,7 @@ Precision impact: Variance only visible in bfloat16 **Critical Insight**: Reduced precision (bfloat16) amplifies tiling variance dramatically: - **Multiply-accumulate** (MatMul): Errors compound quickly, visible in both precisions - **Pure addition** (RMSNorm sum): Errors compound slowly, only visible in bfloat16 -- **Implication**: bfloat16 users need batch-invariant implementations more urgently +- **Implication**: bfloat16 sees more extreme batch variance ### 🔬 Replicating Paper Findings with NKI @@ -220,7 +219,7 @@ However, variance can still occur when: - Using reduced precision (bfloat16) with iterative accumulation - Adapting strategies based on input characteristics -**Our findings directly replicate the Thinking Machines paper**: Batch variance stems from **dynamic tiling of reduction dimensions**, and the solution is **fixed tiling strategies**. NKI makes this easier by design, but developers must still be intentional about tile size choices, especially when using bfloat16 precision. +**My findings directly replicate the Thinking Machines paper**: Batch variance stems from **dynamic tiling of reduction dimensions**, and the solution is **fixed tiling strategies**. NKI makes this easier by design, but developers must still be intentional about tile size choices, especially when using bfloat16 precision. ## Running the Tests From 24e0dd7a0b7d4bdb5435079619b3e2d1c648c76c Mon Sep 17 00:00:00 2001 From: Jlonge4 Date: Wed, 15 Oct 2025 11:52:16 -0400 Subject: [PATCH 05/21] add mermaid --- contributed/batch_invariance/README.md | 56 ++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md index d9a5dd7..0c28b5e 100644 --- a/contributed/batch_invariance/README.md +++ b/contributed/batch_invariance/README.md @@ -23,6 +23,62 @@ Batch variance occurs when **ALL THREE conditions are met**: - NKI (fixed): `K_TILE = 128` always ✗ - NKI (variant): `K_TILE = 64 if K <= 512 else 128` ✓ +```mermaid +flowchart TD + Start[Input Tensor: batch_size x hidden_dim 1024] --> CheckBatch{What is batch_size?} + + CheckBatch -->|batch < 64| SmallBatch[Small Batch Strategy] + CheckBatch -->|64 ≤ batch < 128| MediumBatch[Medium Batch Strategy] + CheckBatch -->|batch ≥ 128| LargeBatch[Large Batch Strategy] + + SmallBatch --> TileSmall[TILE_SIZE = 64] + MediumBatch --> TileMedium[TILE_SIZE = 128] + LargeBatch --> TileLarge[TILE_SIZE = 256] + + TileSmall --> ChunkSmall[Split hidden_dim into 16 chunks] + TileMedium --> ChunkMedium[Split hidden_dim into 8 chunks] + TileLarge --> ChunkLarge[Split hidden_dim into 4 chunks] + + ChunkSmall --> ReduceSmall[Reduce each chunk:
sum elements 0:64
sum elements 64:128
... 16 partial sums] + ChunkMedium --> ReduceMedium[Reduce each chunk:
sum elements 0:128
sum elements 128:256
... 8 partial sums] + ChunkLarge --> ReduceLarge[Reduce each chunk:
sum elements 0:256
sum elements 256:512
... 4 partial sums] + + ReduceSmall --> AccumSmall[Accumulate 16 partials:
p1 + p2 = t1
t1 + p3 = t2
... 15 additions] + ReduceMedium --> AccumMedium[Accumulate 8 partials:
p1 + p2 = t1
t1 + p3 = t2
... 7 additions] + ReduceLarge --> AccumLarge[Accumulate 4 partials:
p1 + p2 = t1
t1 + p3 = t2
... 3 additions] + + AccumSmall --> ResultSmall[result_small
15 rounding errors] + AccumMedium --> ResultMedium[result_medium
7 rounding errors] + AccumLarge --> ResultLarge[result_large
3 rounding errors] + + ResultSmall --> Compare{Compare Results} + ResultMedium --> Compare + ResultLarge --> Compare + + Compare --> NotEqual[❌ result_small ≠ result_medium ≠ result_large
Different accumulation orders
Different floating-point rounding
NON-DETERMINISTIC] + + NotEqual --> Problem[🔥 PROBLEM: Same input data,
different batch sizes yield
different numerical results!] + + Problem --> Solution[✅ SOLUTION: Hardcode TILE_SIZE] + + Solution --> FixedTile[TILE_SIZE = 128 always] + FixedTile --> FixedChunks[Always 8 chunks
Always 7 accumulations
for ALL batch sizes] + FixedChunks --> Deterministic[✅ DETERMINISTIC RESULTS
batch=32: 8 chunks, 7 adds
batch=96: 8 chunks, 7 adds
batch=256: 8 chunks, 7 adds] + + style Start fill:#e3f2fd + style CheckBatch fill:#fff3e0 + style SmallBatch fill:#ffebee + style MediumBatch fill:#e8eaf6 + style LargeBatch fill:#f3e5f5 + style TileSmall fill:#ef5350,color:#fff + style TileMedium fill:#42a5f5,color:#fff + style TileLarge fill:#ab47bc,color:#fff + style NotEqual fill:#ffcdd2 + style Problem fill:#ff5252,color:#fff + style Solution fill:#81c784 + style Deterministic fill:#66bb6a,color:#fff + style FixedTile fill:#4caf50,color:#fff +``` ## Test Environment - **Instance**: `inf2.xlarge` (AWS Trainium) From 0675233d816dcbace9830865009874908fe7cb32 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Mon, 27 Oct 2025 10:17:30 -0400 Subject: [PATCH 06/21] Refactor tests to follow same pattern as TML's Refactor tests for batch invariance and variance in RMSNorm and MatMul. Now follows the same testing pattern as Thinking Machines Labs. --- .../batch_invariance/test_batch_invariance.py | 132 ++++++++++++------ 1 file changed, 91 insertions(+), 41 deletions(-) diff --git a/contributed/batch_invariance/test_batch_invariance.py b/contributed/batch_invariance/test_batch_invariance.py index c469da7..3bfedcd 100644 --- a/contributed/batch_invariance/test_batch_invariance.py +++ b/contributed/batch_invariance/test_batch_invariance.py @@ -6,52 +6,66 @@ import torch_neuronx import numpy as np from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel -from kernels.rmsnorm_split_reduction import nki_rmsnorm_split_reduction -from kernels.matmul_batch_invariant import nki_matmul_kernel as matmul_batch_invariant +from kernels.matmul_batch_invariant import nki_matmul_kernel def test_matmul(): """MatMul test showing K_TILE effect and precision impact""" print("Testing MatMul batch invariance...") - device = 'xla' - M, K, N = 128, 512, 512 # K=512 triggers different behavior! + K, N = 512, 512 + M_TILE = 128 + large_batch = 256 # 2x M_TILE + small_batch = 128 # 1x M_TILE print(f" K={K} -> batch_invariant=True: K_TILE=128, batch_invariant=False: K_TILE=64") print() # Test with bfloat16 print(" Testing with bfloat16:") - a_bf16 = torch.linspace(-1, 1, M * K, device=device).reshape(M, K).to(torch.bfloat16) + a_large_bf16 = torch.linspace(-1, 1, large_batch * K, device=device).reshape(large_batch, K).to(torch.bfloat16) b_bf16 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.bfloat16) - result_inv_bf16 = matmul_batch_invariant(a_bf16, b_bf16, batch_invariant=True) # K_TILE=128 - result_var_bf16 = matmul_batch_invariant(a_bf16, b_bf16, batch_invariant=False) # K_TILE=64 + # Test the SAME 128 rows in different batch contexts + a_small_bf16 = a_large_bf16[:small_batch, :] + + # Process as small batch (128 rows) + result_small_bf16 = nki_matmul_kernel(a_small_bf16, b_bf16, batch_invariant=True) - diff_bf16 = torch.max(torch.abs(result_inv_bf16 - result_var_bf16)).item() + # Process as part of large batch (256 rows) + result_large_bf16 = nki_matmul_kernel(a_large_bf16, b_bf16, batch_invariant=False) + + # Compare the SAME rows + diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item() print(f" Max difference between K_TILE strategies: {diff_bf16:.6f}") print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") - print() # Test with float32 print(" Testing with float32:") - a_f32 = torch.linspace(-1, 1, M * K, device=device).reshape(M, K).to(torch.float32) + a_large_f32 = torch.linspace(-1, 1, large_batch * K, device=device).reshape(large_batch, K).to(torch.float32) b_f32 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.float32) - result_inv_f32 = matmul_batch_invariant(a_f32, b_f32, batch_invariant=True) # K_TILE=128 - result_var_f32 = matmul_batch_invariant(a_f32, b_f32, batch_invariant=False) # K_TILE=64 + # Test the SAME 128 rows in different batch contexts + a_small_f32 = a_large_f32[:small_batch, :] + + # Process as small batch (128 rows) + result_small_f32 = nki_matmul_kernel(a_small_f32, b_f32, batch_invariant=True) + + # Process as part of large batch (256 rows) + result_large_f32 = nki_matmul_kernel(a_large_f32, b_f32, batch_invariant=False) - diff_f32 = torch.max(torch.abs(result_inv_f32 - result_var_f32)).item() + # Compare the SAME rows + diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item() print(f" Max difference between K_TILE strategies: {diff_f32:.6f}") print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") - print() + print(f" Precision impact: bfloat16 error is {diff_bf16/diff_f32 if diff_f32 > 0 else 'N/A'}x larger than float32") print(f" This demonstrates how reduced precision amplifies tiling strategy effects") - -def test_rmsnorm(): + +def test_rmsnorm_invariant(): """RMSNorm demonstrates batch INVARIANCE (not variance)""" print("Testing RMSNorm batch invariance...") @@ -81,15 +95,39 @@ def test_rmsnorm(): print(f" Each row computed independently, reduction is atomic") print(f" Tile size only affects parallelism, not computation order") - -def test_rmsnorm_split_reduction(): - """RMSNorm with SPLIT REDUCTION demonstrates TRUE batch VARIANCE""" - print("Testing RMSNorm with Split Reduction...") - print(" (Tiling the HIDDEN dimension creates different accumulation orders)") +def test_rmsnorm_variant(): + """RMSNorm demonstrates batch INVARIANCE (not variance)""" + print("Testing RMSNorm batch variance...") + + device = 'xla' + hidden_dim = 256 + + # Create a large input with many rows + large_batch = 128 + a_large = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.bfloat16) + g = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16) + + # Test the SAME 32 rows in different batch contexts + a_small = a_large[:32, :] + + # Process as small batch (32 rows) + result_small = nki_rmsnorm_kernel(a_small, g, batch_invariant=True) + + # Process as part of large batch (128 rows) + result_large = nki_rmsnorm_kernel(a_large, g, batch_invariant=False) + diff_bf16 = torch.max(torch.abs(result_small - result_large[:32])).item() + print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") + print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + + +def test_rmsnorm_accuracy_diff(): + """RMSNorm with accuracy difference demonstrates bfloat16 vs float32 effects on the result""" + print("Testing RMSNorm with varying accuracies...") device = 'xla' - hidden_dim = 512 # Use 512 to see clear difference - batch_size = 64 + hidden_dim = 512 + large_batch = 128 + small_batch = 32 print(f" hidden_dim={hidden_dim}") print(f" batch_invariant=True: HIDDEN_TILE=256 (2 chunks, 1 accumulation)") @@ -98,36 +136,45 @@ def test_rmsnorm_split_reduction(): # Test with bfloat16 print(" Testing with bfloat16:") - a_bf16 = torch.linspace(-1, 1, batch_size * hidden_dim, device=device).reshape(batch_size, hidden_dim).to(torch.bfloat16) + a_large_bf16 = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.bfloat16) g_bf16 = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16) - result_inv_bf16 = nki_rmsnorm_split_reduction(a_bf16, g_bf16, batch_invariant=True) # HIDDEN_TILE=256 - result_var_bf16 = nki_rmsnorm_split_reduction(a_bf16, g_bf16, batch_invariant=False) # HIDDEN_TILE=128 + # Test the SAME 32 rows in different batch contexts + a_small_bf16 = a_large_bf16[:small_batch, :] + + # Process as small batch (32 rows) + result_small_bf16 = nki_rmsnorm_kernel(a_small_bf16, g_bf16, batch_invariant=True) # HIDDEN_TILE=256 - diff_bf16 = torch.max(torch.abs(result_inv_bf16 - result_var_bf16)).item() + # Process as part of large batch (128 rows) + result_large_bf16 = nki_rmsnorm_kernel(a_large_bf16, g_bf16, batch_invariant=False) # HIDDEN_TILE=128 + + # Compare the SAME rows + diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item() print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") - print() # Test with float32 print(" Testing with float32:") - a_f32 = torch.linspace(-1, 1, batch_size * hidden_dim, device=device).reshape(batch_size, hidden_dim).to(torch.float32) + a_large_f32 = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.float32) g_f32 = torch.ones(hidden_dim, device=device, dtype=torch.float32) - result_inv_f32 = nki_rmsnorm_split_reduction(a_f32, g_f32, batch_invariant=True) # HIDDEN_TILE=256 - result_var_f32 = nki_rmsnorm_split_reduction(a_f32, g_f32, batch_invariant=False) # HIDDEN_TILE=128 + # Test the SAME 32 rows in different batch contexts + a_small_f32 = a_large_f32[:small_batch, :] + + # Process as small batch (32 rows) + result_small_f32 = nki_rmsnorm_kernel(a_small_f32, g_f32, batch_invariant=True) # HIDDEN_TILE=256 - diff_f32 = torch.max(torch.abs(result_inv_f32 - result_var_f32)).item() + # Process as part of large batch (128 rows) + result_large_f32 = nki_rmsnorm_kernel(a_large_f32, g_f32, batch_invariant=False) # HIDDEN_TILE=128 + + # Compare the SAME rows + diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item() print(f" Max difference between HIDDEN_TILE strategies: {diff_f32:.6f}") print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") - print() - print(f" Precision impact: bfloat16 error is {diff_bf16/diff_f32 if diff_f32 > 0 else 'N/A'}x larger than float32") - print(f" ✓ Split reduction creates batch variance in BOTH precisions!") - print(f" Different hidden tile sizes → different accumulation order") - print(f" This is analogous to MatMul's K_TILE effect") - + + print(f" Precision impact: bfloat16 error is clear where float32 makes the difference negligible for this test") if __name__ == "__main__": print("Batch Invariance Test") @@ -136,14 +183,17 @@ def test_rmsnorm_split_reduction(): test_matmul() print() print("=" * 80) - test_rmsnorm() + test_rmsnorm_invariant() + print() + print("=" * 80) + test_rmsnorm_variant() print() print("=" * 80) - test_rmsnorm_split_reduction() + test_rmsnorm_accuracy_diff() print("\n" + "=" * 80) print("SUMMARY:") print(" • MatMul: K_TILE variance - different reduction chunking") print(" • RMSNorm (standard): Batch-invariant - atomic reduction") print(" • RMSNorm (split): HIDDEN_TILE variance - reduction chunking") - print("\nDone!") \ No newline at end of file + print("\nDone!") From 09a1c29021d9c5bdc31716dba11a483f0d2b0a97 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Mon, 27 Oct 2025 12:26:43 -0400 Subject: [PATCH 07/21] Delete contributed/batch_invariance/kernels/rmsnorm_split_reduction.py --- .../kernels/rmsnorm_split_reduction.py | 104 ------------------ 1 file changed, 104 deletions(-) delete mode 100644 contributed/batch_invariance/kernels/rmsnorm_split_reduction.py diff --git a/contributed/batch_invariance/kernels/rmsnorm_split_reduction.py b/contributed/batch_invariance/kernels/rmsnorm_split_reduction.py deleted file mode 100644 index 524ec5c..0000000 --- a/contributed/batch_invariance/kernels/rmsnorm_split_reduction.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -RMSNorm with Split Reduction - Demonstrates TRUE Batch Variance - -This kernel tiles the HIDDEN DIMENSION (reduction axis) instead of just the batch dimension. -This creates different accumulation orders and breaks batch-invariance! -""" - -import math -import neuronxcc.nki as nki -import neuronxcc.nki.language as nl - - -@nki.jit -def nki_rmsnorm_split_reduction(a_tensor, g_tensor, batch_invariant=True): - """ - RMSNorm with split reduction along hidden dimension - - batch_invariant=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) - batch_invariant=False: HIDDEN_TILE=128 (more chunks, more accumulations) - - This demonstrates REAL batch variance because different tile sizes - change the order of floating-point additions during reduction. - """ - out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, - buffer=nl.shared_hbm) - - assert a_tensor.shape[1] == g_tensor.shape[0] - - num_rows = a_tensor.shape[0] - hidden_dim = a_tensor.shape[1] - BATCH_TILE = 128 - - # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) - # Different sizes = different number of accumulations = variance! - if batch_invariant: - HIDDEN_TILE = 256 # Fewer chunks (e.g., 2 for hidden_dim=512) - else: - HIDDEN_TILE = 128 # More chunks (e.g., 4 for hidden_dim=512) - - ix = nl.arange(BATCH_TILE)[:, None] - iw = nl.arange(1)[:, None] - - # Process batch in tiles - for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): - - # SPLIT REDUCTION: Accumulate partial sums across hidden dimension chunks - # Use PSUM for accumulation (always float32 internally) - partial_square_sum = nl.zeros((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.psum) - - # Iterate over hidden dimension in chunks - num_hidden_tiles = math.ceil(hidden_dim / HIDDEN_TILE) - for h in nl.affine_range(num_hidden_tiles): - h_start = h * HIDDEN_TILE - - # Create indices for this hidden chunk (always use full HIDDEN_TILE, mask later) - iy = nl.arange(HIDDEN_TILE)[None, :] - - # Create mask for valid hidden indices - valid_mask = ((i * BATCH_TILE + ix < num_rows) & - (h * HIDDEN_TILE + iy < hidden_dim)) - - # Load a CHUNK of the hidden dimension with proper indexing - a_chunk = nl.load( - a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], - mask=valid_mask - ) - - # Square this chunk - in_square_chunk = nl.square(a_chunk) - - # Reduce this chunk (sum along hidden dimension) - # Mask ensures we only sum valid elements - chunk_sum = nl.sum(in_square_chunk, axis=[1], keepdims=True, - mask=valid_mask) - - # ACCUMULATE: This is where variance enters! - # Different HIDDEN_TILE sizes mean different number of additions - partial_square_sum += chunk_sum - - # Compute mean and RMS - mean = partial_square_sum / hidden_dim - rms_reciprocal = nl.rsqrt(mean) - - # Now load full row for normalization - iy_full = nl.arange(hidden_dim)[None, :] - a_tile = nl.load( - a_tensor[i * BATCH_TILE + ix, iy_full], - mask=(i * BATCH_TILE + ix < num_rows) - ) - - # Normalize by RMS - out_tile = nl.multiply(a_tile, rms_reciprocal) - - # Apply weight - g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_full]) - g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) - out_tile = nl.multiply(out_tile, g_bcast, - mask=(i * BATCH_TILE + ix < num_rows)) - - # Store result - nl.store(out_tensor[i * BATCH_TILE + ix, iy_full], value=out_tile, - mask=(i * BATCH_TILE + ix < num_rows)) - - return out_tensor From bf08add646a65e311b1a7cf1adaf5dea30e2f116 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:53:41 -0400 Subject: [PATCH 08/21] Implement isa matmul version Added ISA kernel --- .../kernels/matmul_batch_invariant.py | 52 ++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/contributed/batch_invariance/kernels/matmul_batch_invariant.py b/contributed/batch_invariance/kernels/matmul_batch_invariant.py index 7e52b09..7be3727 100644 --- a/contributed/batch_invariance/kernels/matmul_batch_invariant.py +++ b/contributed/batch_invariance/kernels/matmul_batch_invariant.py @@ -7,10 +7,58 @@ import neuronxcc.nki as nki import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa +@nki.compiler.skip_middle_end_transformations +@nki.jit +def nki_matmul_kernel_isa(a, b, batch_invariant=True): + """ + Matrix multiplication with batch invariance parameter + + batch_invariant=True: Uses K_TILE=128 + batch_invariant=False: Dynamic K_TILE size used + + This demonstrates how different K tiling affects numerical results. + """ + K, M = a.shape + N = b.shape[1] + M_TILE = 128 + + # ONLY DIFFERENCE: K_TILE strategy + if batch_invariant: + K_TILE = 128 # Always hardcoded + else: + K_TILE = 64 if K <= 512 else 128 # Adaptive + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + for m in nl.affine_range(M // M_TILE): + # Accumulator for this M chunk + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Reduction over K + for k in nl.affine_range(K // K_TILE): + # Load a: [K_TILE, M_TILE] + i_a_p, i_a_f = nl.mgrid[0:K_TILE, 0:M_TILE] + a_tile = nl.load(a[k*K_TILE + i_a_p, m*M_TILE + i_a_f]) + + # Load b: [K_TILE, N] + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + # Matmul + + print(a_tile.shape, b_tile.shape) + c_psum += nisa.nc_matmul(a_tile, b_tile) + # Store this M chunk + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result @nki.jit -def nki_matmul_kernel(a, b, batch_invariant=True): +def nki_matmul_kernel_lang(a, b, batch_invariant=True): """ Matrix multiplication with batch invariance parameter @@ -53,4 +101,4 @@ def nki_matmul_kernel(a, b, batch_invariant=True): c_sbuf = nl.copy(c_psum, dtype=result.dtype) nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) - return result \ No newline at end of file + return result From 1af87da42a90e51b7623b64ffa78c150e434b7d5 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:54:15 -0400 Subject: [PATCH 09/21] Enhance matmul and RMSNorm tests for correctness Added tests for matmul kernel correctness and batch variance effects. Updated existing tests to improve clarity and structure. --- .../batch_invariance/test_batch_invariance.py | 367 +++++++++++++++--- 1 file changed, 322 insertions(+), 45 deletions(-) diff --git a/contributed/batch_invariance/test_batch_invariance.py b/contributed/batch_invariance/test_batch_invariance.py index 3bfedcd..659b491 100644 --- a/contributed/batch_invariance/test_batch_invariance.py +++ b/contributed/batch_invariance/test_batch_invariance.py @@ -3,16 +3,122 @@ """ import torch +import time import torch_neuronx import numpy as np from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel -from kernels.matmul_batch_invariant import nki_matmul_kernel +from kernels.matmul_batch_invariant import nki_matmul_kernel_isa, nki_matmul_kernel_lang +# Prove that the kernels match pytorch and are functionally correct +def test_matmul_kernel_correctness(): + """ + Verify NKI matmul kernels produce correct results vs PyTorch. + + Validates mathematical correctness before analyzing batch invariance effects. + """ + print("Testing MatMul Correctness...") + device = 'xla' + + # Test dimensions + M, K, N = 256, 512, 512 + + print(f" Matrix dimensions: [{M}, {K}] @ [{K}, {N}] = [{M}, {N}]") + print() + + # Create test data + np.random.seed(42) + a_np = np.random.randn(M, K).astype(np.float32) + b_np = np.random.randn(K, N).astype(np.float32) + + # PyTorch reference (CPU) + a_torch = torch.tensor(a_np, dtype=torch.float32) + b_torch = torch.tensor(b_np, dtype=torch.float32) + + print(" Computing PyTorch reference (CPU)...") + start = time.time() + ref_output = torch.matmul(a_torch, b_torch) + ref_time = time.time() - start + print(f" Time: {ref_time:.6f}s") + print(f" Output shape: {ref_output.shape}") + print(f" First values: {ref_output[0, :5].numpy()}") + print() + + # Test Lang kernel - expects [M, K] @ [K, N] + print(" Testing Lang kernel (nl.matmul)...") + a_xla = torch.tensor(a_np, dtype=torch.float32, device=device) # [M, K] + b_xla = torch.tensor(b_np, dtype=torch.float32, device=device) # [K, N] + + start = time.time() + output_lang = nki_matmul_kernel_lang(a_xla, b_xla, batch_invariant=True) + lang_time = time.time() - start + + output_lang_cpu = output_lang.cpu() + print(f" Time: {lang_time:.6f}s") + print(f" Output shape: {output_lang_cpu.shape}") + print(f" First values: {output_lang_cpu[0, :5].numpy()}") + + lang_match = torch.allclose(ref_output, output_lang_cpu, atol=1e-4, rtol=1e-2) + max_diff_lang = torch.max(torch.abs(ref_output - output_lang_cpu)).item() + + if lang_match: + print(f" ✓ Matches PyTorch reference") + else: + print(f" ✗ Differs from PyTorch reference") + print(f" Max difference: {max_diff_lang:.6f}") + print() + + # Test ISA kernel - expects [K, M] @ [K, N] + print(" Testing ISA kernel (nisa.nc_matmul)...") + a_xla_t = torch.tensor(a_np.T, dtype=torch.float32, device=device) # [K, M] - transposed! + b_xla = torch.tensor(b_np, dtype=torch.float32, device=device) # [K, N] + + start = time.time() + output_isa = nki_matmul_kernel_isa(a_xla_t, b_xla, batch_invariant=True) + isa_time = time.time() - start + + output_isa_cpu = output_isa.cpu() + print(f" Time: {isa_time:.6f}s") + print(f" Output shape: {output_isa_cpu.shape}") + print(f" First values: {output_isa_cpu[0, :5].numpy()}") + + isa_match = torch.allclose(ref_output, output_isa_cpu, atol=1e-4, rtol=1e-2) + max_diff_isa = torch.max(torch.abs(ref_output - output_isa_cpu)).item() + + if isa_match: + print(f" ✓ Matches PyTorch reference") + else: + print(f" ✗ Differs from PyTorch reference") + print(f" Max difference: {max_diff_isa:.6f}") + print() + + # Summary + print("=" * 80) + if lang_match and isa_match: + print("✓ Both kernels produce correct results") + else: + print("✗ One or more kernels differ from PyTorch reference") + if not lang_match: + print(f" Lang kernel max error: {max_diff_lang:.6f}") + if not isa_match: + print(f" ISA kernel max error: {max_diff_isa:.6f}") + + assert lang_match, f"Lang kernel doesn't match PyTorch (max diff: {max_diff_lang})" + assert isa_match, f"ISA kernel doesn't match PyTorch (max diff: {max_diff_isa})" -def test_matmul(): - """MatMul test showing K_TILE effect and precision impact""" - print("Testing MatMul batch invariance...") +def test_matmul_isa(): + """ + ISA kernel K-tiling batch variance with quantization erasure. + + Expected: bfloat16 error = 0.0 despite float32 showing differences + Reason: nisa.nc_matmul produces float32 errors below bfloat16 threshold (~0.008) + Result: Demonstrates hardware-level numerical stability + + Returns: + dict: Test results with float32 and bfloat16 errors + """ + print("Testing MatMul batch variance (ISA kernel)...") device = 'xla' + K, N = 512, 512 M_TILE = 128 large_batch = 256 # 2x M_TILE @@ -21,39 +127,91 @@ def test_matmul(): print(f" K={K} -> batch_invariant=True: K_TILE=128, batch_invariant=False: K_TILE=64") print() - # Test with bfloat16 - print(" Testing with bfloat16:") - a_large_bf16 = torch.linspace(-1, 1, large_batch * K, device=device).reshape(large_batch, K).to(torch.bfloat16) - b_bf16 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.bfloat16) + # Create data ONCE in float32 - ISA kernel needs [K, M] layout! + print(" Creating data in float32...") + a_large_f32 = torch.linspace(-1, 1, large_batch * K, device=device).reshape(K, large_batch).to(torch.float32) + b_f32 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.float32) - # Test the SAME 128 rows in different batch contexts - a_small_bf16 = a_large_bf16[:small_batch, :] + # Test with float32 FIRST + print(" Testing with float32:") + a_small_f32 = a_large_f32[:, :small_batch] # [K, 128] - # Process as small batch (128 rows) - result_small_bf16 = nki_matmul_kernel(a_small_bf16, b_bf16, batch_invariant=True) + result_small_f32 = nki_matmul_kernel_isa(a_small_f32, b_f32, batch_invariant=True) + result_large_f32 = nki_matmul_kernel_isa(a_large_f32, b_f32, batch_invariant=False) - # Process as part of large batch (256 rows) - result_large_bf16 = nki_matmul_kernel(a_large_bf16, b_bf16, batch_invariant=False) + diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item() + print(f" Max difference: {diff_f32:.6f}") + print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") + print() + + # Cast to bfloat16 + print(" Testing with bfloat16:") + a_large_bf16 = a_large_f32.to(torch.bfloat16) + b_bf16 = b_f32.to(torch.bfloat16) + a_small_bf16 = a_large_bf16[:, :small_batch] + + result_small_bf16 = nki_matmul_kernel_isa(a_small_bf16, b_bf16, batch_invariant=True) + result_large_bf16 = nki_matmul_kernel_isa(a_large_bf16, b_bf16, batch_invariant=False) - # Compare the SAME rows diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item() - print(f" Max difference between K_TILE strategies: {diff_bf16:.6f}") + print(f" Max difference: {diff_bf16:.6f}") print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") print() - # Test with float32 - print(" Testing with float32:") + if diff_f32 > 0: + ratio = diff_bf16 / diff_f32 + print(f" Precision impact: bfloat16 error is {ratio:.2f}x {'larger' if diff_bf16 > diff_f32 else 'smaller'} than float32") + if diff_bf16 == 0.0: + print(f" Note: Float32 error ({diff_f32:.6f}) is below bfloat16 quantization threshold (~0.008)") + print(f" Quantization erases the difference rather than amplifying it") + else: + ratio = 0.0 + print(f" Precision impact: N/A (no float32 difference detected)") + + return { + "kernel": "ISA (nisa.nc_matmul)", + "float32_error": diff_f32, + "bfloat16_error": diff_bf16, + "amplification": ratio + } + +def test_matmul_lang(): + """ + Lang kernel K-tiling batch variance with precision amplification. + + Expected: bfloat16 error ~170x larger than float32 + Reason: nl.matmul produces float32 errors above bfloat16 threshold + Result: Demonstrates how reduced precision amplifies tiling strategy effects + + Returns: + dict: Test results with float32 and bfloat16 errors + """ + print("Testing MatMul batch variance (Lang kernel)...") + device = 'xla' + + K, N = 512, 512 + M_TILE = 128 + large_batch = 256 # 2x M_TILE + small_batch = 128 # 1x M_TILE + + print(f" K={K} -> batch_invariant=True: K_TILE=128, batch_invariant=False: K_TILE=64") + print() + + # Create data ONCE in float32 - single source of truth + print(" Creating data in float32...") a_large_f32 = torch.linspace(-1, 1, large_batch * K, device=device).reshape(large_batch, K).to(torch.float32) b_f32 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.float32) + # Test with float32 FIRST + print(" Testing with float32:") # Test the SAME 128 rows in different batch contexts a_small_f32 = a_large_f32[:small_batch, :] # Process as small batch (128 rows) - result_small_f32 = nki_matmul_kernel(a_small_f32, b_f32, batch_invariant=True) + result_small_f32 = nki_matmul_kernel_lang(a_small_f32, b_f32, batch_invariant=True) # Process as part of large batch (256 rows) - result_large_f32 = nki_matmul_kernel(a_large_f32, b_f32, batch_invariant=False) + result_large_f32 = nki_matmul_kernel_lang(a_large_f32, b_f32, batch_invariant=False) # Compare the SAME rows diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item() @@ -61,12 +219,51 @@ def test_matmul(): print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") print() - print(f" Precision impact: bfloat16 error is {diff_bf16/diff_f32 if diff_f32 > 0 else 'N/A'}x larger than float32") - print(f" This demonstrates how reduced precision amplifies tiling strategy effects") - + # Cast to bfloat16 from the SAME float32 source + print(" Testing with bfloat16:") + a_large_bf16 = a_large_f32.to(torch.bfloat16) + b_bf16 = b_f32.to(torch.bfloat16) + + # Test the SAME 128 rows in different batch contexts + a_small_bf16 = a_large_bf16[:small_batch, :] + + # Process as small batch (128 rows) + result_small_bf16 = nki_matmul_kernel_lang(a_small_bf16, b_bf16, batch_invariant=True) + + # Process as part of large batch (256 rows) + result_large_bf16 = nki_matmul_kernel_lang(a_large_bf16, b_bf16, batch_invariant=False) + + # Compare the SAME rows + diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item() + print(f" Max difference between K_TILE strategies: {diff_bf16:.6f}") + print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + print() + + if diff_f32 > 0: + ratio = diff_bf16 / diff_f32 + print(f" Precision impact: bfloat16 error is {ratio:.2f}x larger than float32") + print(f" This demonstrates how reduced precision amplifies tiling strategy effects") + else: + ratio = 0.0 + print(f" Precision impact: N/A (no float32 difference detected)") + + return { + "kernel": "Lang (nl.matmul)", + "float32_error": diff_f32, + "bfloat16_error": diff_bf16, + "amplification": ratio + } def test_rmsnorm_invariant(): - """RMSNorm demonstrates batch INVARIANCE (not variance)""" + """ + RMSNorm demonstrates batch INVARIANCE with consistent tiling. + + When using the same batch_invariant=True setting, results should be + identical regardless of batch size because each row is computed independently. + + Returns: + dict: Test results showing invariance + """ print("Testing RMSNorm batch invariance...") device = 'xla' @@ -87,16 +284,33 @@ def test_rmsnorm_invariant(): result_large = nki_rmsnorm_kernel(a_large, g, batch_invariant=True) # Compare the SAME rows - match = torch.allclose(result_small, result_large[:32], atol=1e-6) + diff = torch.max(torch.abs(result_small - result_large[:32])).item() + match = diff < 1e-6 + print(f" First 32 rows: batch=32 vs batch=128: {'MATCH ✓' if match else 'DIFFER ✗'}") + print(f" Max difference: {diff:.6f}") if match: print(f" ✓ RMSNorm is batch-invariant!") print(f" Each row computed independently, reduction is atomic") print(f" Tile size only affects parallelism, not computation order") + + return { + "test": "RMSNorm Invariant", + "max_difference": diff, + "is_invariant": match + } def test_rmsnorm_variant(): - """RMSNorm demonstrates batch INVARIANCE (not variance)""" + """ + RMSNorm demonstrates batch VARIANCE with different tiling strategies. + + When using different batch_invariant settings (True vs False), results may + differ due to different HIDDEN_TILE sizes affecting reduction chunking. + + Returns: + dict: Test results showing variance + """ print("Testing RMSNorm batch variance...") device = 'xla' @@ -110,20 +324,38 @@ def test_rmsnorm_variant(): # Test the SAME 32 rows in different batch contexts a_small = a_large[:32, :] - # Process as small batch (32 rows) + # Process as small batch (32 rows) with batch_invariant=True result_small = nki_rmsnorm_kernel(a_small, g, batch_invariant=True) - # Process as part of large batch (128 rows) + # Process as part of large batch (128 rows) with batch_invariant=False result_large = nki_rmsnorm_kernel(a_large, g, batch_invariant=False) diff_bf16 = torch.max(torch.abs(result_small - result_large[:32])).item() - print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") - print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") + print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + + if diff_bf16 > 1e-6: + print(f" ✗ Different HIDDEN_TILE sizes produce different results") + print(f" This demonstrates tiling strategy affects reduction order") + + return { + "test": "RMSNorm Variant", + "max_difference": diff_bf16, + "is_invariant": diff_bf16 < 1e-6 + } def test_rmsnorm_accuracy_diff(): - """RMSNorm with accuracy difference demonstrates bfloat16 vs float32 effects on the result""" - print("Testing RMSNorm with varying accuracies...") + """ + RMSNorm HIDDEN_TILE variance with precision effects. + + Tests how different HIDDEN_TILE sizes affect reduction chunking and + whether precision amplifies these differences. + + Returns: + dict: Test results with float32 and bfloat16 errors + """ + print("Testing RMSNorm HIDDEN_TILE variance...") device = 'xla' hidden_dim = 512 large_batch = 128 @@ -174,26 +406,71 @@ def test_rmsnorm_accuracy_diff(): print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") print() - print(f" Precision impact: bfloat16 error is clear where float32 makes the difference negligible for this test") + if diff_f32 > 0: + ratio = diff_bf16 / diff_f32 + print(f" Precision impact: bfloat16 error is {ratio:.2f}x {'larger' if diff_bf16 > diff_f32 else 'smaller'} than float32") + else: + ratio = 0.0 + print(f" Precision impact: N/A (no float32 difference detected)") + + return { + "kernel": "RMSNorm (HIDDEN_TILE)", + "float32_error": diff_f32, + "bfloat16_error": diff_bf16, + "amplification": ratio + } if __name__ == "__main__": + import pandas as pd + print("Batch Invariance Test") print("=" * 80) - test_matmul() - print() + # Run correctness test + test_matmul_kernel_correctness() print("=" * 80) - test_rmsnorm_invariant() - print() + + # Test Lang kernel + print("\nRunning Lang kernel test...") + lang_results = test_matmul_lang() + print("=" * 80) - test_rmsnorm_variant() - print() + + # Test ISA kernel + print("\nRunning ISA kernel test...") + isa_results = test_matmul_isa() + print("=" * 80) - test_rmsnorm_accuracy_diff() + + # Test RMSNorm invariance + print("=" * 80) + print("\nRunning RMSNorm batch invariance test...") + rmsnorm_invariant = test_rmsnorm_invariant() + + print("=" * 80) + + # Test RMSNorm variance + print("\nRunning RMSNorm batch variance test...") + rmsnorm_variant = test_rmsnorm_variant() + + print("=" * 80) + + # Test RMSNorm HIDDEN_TILE precision effects + print("\nRunning RMSNorm HIDDEN_TILE variance test...") + rmsnorm_results = test_rmsnorm_accuracy_diff() print("\n" + "=" * 80) - print("SUMMARY:") - print(" • MatMul: K_TILE variance - different reduction chunking") - print(" • RMSNorm (standard): Batch-invariant - atomic reduction") - print(" • RMSNorm (split): HIDDEN_TILE variance - reduction chunking") - print("\nDone!") + print("SUMMARY") + print("=" * 80) + + # Create results dataframes + print("\nMatMul & RMSNorm Batch Variance Results:") + variance_df = pd.DataFrame([lang_results, isa_results, rmsnorm_results]) + print(variance_df.to_string(index=False)) + print() + + print("\nRMSNorm Invariance vs Variance:") + invariance_df = pd.DataFrame([rmsnorm_invariant, rmsnorm_variant]) + print(invariance_df.to_string(index=False)) + print() + From a4814d0b85fc4eb69f96669982fe506f0fae336d Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:55:16 -0400 Subject: [PATCH 10/21] Enhance RMSNorm kernel for batch variance demonstration Updated RMSNorm kernel to demonstrate batch variance with split reduction along the hidden dimension. Adjusted tile sizes based on batch invariance parameter to illustrate the impact on floating-point addition order during reduction. --- .../kernels/rmsnorm_batch_invariant.py | 163 ++++++++++-------- 1 file changed, 93 insertions(+), 70 deletions(-) diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py index 4917eae..ab005d7 100644 --- a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -1,5 +1,8 @@ """ -Batch-Invariant RMSNorm Kernel +RMSNorm to demonstrate Batch Variance + +This kernel tiles the HIDDEN DIMENSION (reduction axis) instead of just the batch dimension. +This creates different accumulation orders and breaks batch-invariance! """ import math @@ -9,73 +12,93 @@ @nki.jit def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): - """ - RMSNorm with batch invariance parameter - - This demonstrates batch invariance testing: - - batch_invariant=True: Always uses tile_size=128 (same strategy regardless of batch) - - batch_invariant=False: Adapts tile_size based on batch size (different strategies) - - This shows that varying the tiling strategy based on batch size does NOT affect results as we are not reducing across the batch dimension - """ - out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, - buffer=nl.shared_hbm) - - # Make sure shapes match - assert a_tensor.shape[1] == g_tensor.shape[0] - - num_rows = a_tensor.shape[0] - hidden_dim = a_tensor.shape[1] - - if batch_invariant: - # INVARIANT: Fixed strategy regardless of batch size - tile_size = 128 - else: - # Also INVARIANT: Strategy changes based on batch size - # Small batches get smaller tiles -> different processing pattern - if num_rows <= 64: - tile_size = 32 # Small batch: smaller tiles + """ + RMSNorm with split reduction along hidden dimension + + batch_invariant=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) + batch_invariant=False: HIDDEN_TILE=128 (more chunks, more accumulations) + + This demonstrates REAL batch variance because different tile sizes + change the order of floating-point additions during reduction. + """ + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) + + assert a_tensor.shape[1] == g_tensor.shape[0] + + num_rows = a_tensor.shape[0] + hidden_dim = a_tensor.shape[1] + BATCH_TILE = 128 + + # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) + # Different sizes = different number of accumulations = variance! + if batch_invariant: + HIDDEN_TILE = 256 # Fewer chunks (e.g., 2 for hidden_dim=512) else: - tile_size = 128 # Large batch: larger tiles - - # Generate tensor indices based on tile_size - ix = nl.arange(tile_size)[:, None] - iw = nl.arange(1)[:, None] - iy = nl.arange(hidden_dim)[None, :] - - # Load RMSNorm weight once - g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy]) - - # Process tile_size rows at a time - for i in nl.affine_range(math.ceil(num_rows / tile_size)): - - # Load input data from external memory to on-chip memory - a_tile = nl.load(a_tensor[i * tile_size + ix, iy], - mask=(i * tile_size + ix < num_rows)) - - # Compute element-wise square of a_tensor - in_square = nl.square(a_tile) - - # Calculate sum of squared elements, along last dimension - square_sum = nl.sum(in_square, axis=[1]) - - # Scale and get a reciprocal - mean = square_sum / hidden_dim - - # Take square root of mean and then reciprocal with rsqrt API - rms_reciprocal = nl.rsqrt(mean) - - # Scale the input tensor - out_tile = nl.multiply(a_tile, rms_reciprocal) - - # Broadcast weight along first axis to match tensor shape - g_bcast = g_tile.broadcast_to((tile_size, hidden_dim)) - - # Multiply with the RMSNorm weight - out_tile[...] = nl.multiply(out_tile, g_bcast, - mask=(i * tile_size + ix < num_rows)) - - # store the results back to external memory - nl.store(out_tensor[i * tile_size + ix, iy], value=out_tile, - mask=(i * tile_size + ix < num_rows)) - - return out_tensor \ No newline at end of file + HIDDEN_TILE = 128 # More chunks (e.g., 4 for hidden_dim=512) + + ix = nl.arange(BATCH_TILE)[:, None] + iw = nl.arange(1)[:, None] + + # Process batch in tiles + for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): + + # SPLIT REDUCTION: Accumulate partial sums across hidden dimension chunks + # Use PSUM for accumulation (always float32 internally) + partial_square_sum = nl.zeros((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.psum) + + # Iterate over hidden dimension in chunks + num_hidden_tiles = math.ceil(hidden_dim / HIDDEN_TILE) + for h in nl.affine_range(num_hidden_tiles): + h_start = h * HIDDEN_TILE + + # Create indices for this hidden chunk (always use full HIDDEN_TILE, mask later) + iy = nl.arange(HIDDEN_TILE)[None, :] + + # Create mask for valid hidden indices + valid_mask = ((i * BATCH_TILE + ix < num_rows) & + (h * HIDDEN_TILE + iy < hidden_dim)) + + # Load a CHUNK of the hidden dimension with proper indexing + a_chunk = nl.load( + a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], + mask=valid_mask + ) + + # Square this chunk + in_square_chunk = nl.square(a_chunk) + + # Reduce this chunk (sum along hidden dimension) + # Mask ensures we only sum valid elements + chunk_sum = nl.sum(in_square_chunk, axis=[1], keepdims=True, + mask=valid_mask) + + # ACCUMULATE: This is where variance enters! + # Different HIDDEN_TILE sizes mean different number of additions + partial_square_sum += chunk_sum + + # Compute mean and RMS + mean = partial_square_sum / hidden_dim + rms_reciprocal = nl.rsqrt(mean) + + # Now load full row for normalization + iy_full = nl.arange(hidden_dim)[None, :] + a_tile = nl.load( + a_tensor[i * BATCH_TILE + ix, iy_full], + mask=(i * BATCH_TILE + ix < num_rows) + ) + + # Normalize by RMS + out_tile = nl.multiply(a_tile, rms_reciprocal) + + # Apply weight + g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_full]) + g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) + out_tile = nl.multiply(out_tile, g_bcast, + mask=(i * BATCH_TILE + ix < num_rows)) + + # Store result + nl.store(out_tensor[i * BATCH_TILE + ix, iy_full], value=out_tile, + mask=(i * BATCH_TILE + ix < num_rows)) + + return out_tensor From 0f0b6f94369020b89a03bd84a80f53e37de1854f Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 29 Oct 2025 16:08:17 -0400 Subject: [PATCH 11/21] update readme --- contributed/batch_invariance/README.md | 151 +++++++++++++++++++++---- 1 file changed, 128 insertions(+), 23 deletions(-) diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md index 0c28b5e..d7b7d28 100644 --- a/contributed/batch_invariance/README.md +++ b/contributed/batch_invariance/README.md @@ -19,7 +19,7 @@ Batch variance occurs when **ALL THREE conditions are met**: - `nl.sum(entire_row)` ✗ Atomic, no variance 3. **Dynamic tile size based on input characteristics** - - CUDA: Adapts K strategy based on batch size ✓ + - CUDA SplitK: Adapts K strategy based on batch size ✓ - NKI (fixed): `K_TILE = 128` always ✗ - NKI (variant): `K_TILE = 64 if K <= 512 else 128` ✓ @@ -89,22 +89,24 @@ flowchart TD ## Test Suite Overview -We test three kernel implementations: +We test four kernel implementations: -1. **MatMul with K_TILE variation** - Demonstrates reduction dimension tiling variance -2. **RMSNorm (standard)** - Demonstrates natural batch invariance with atomic reductions -3. **RMSNorm (split reduction)** - Demonstrates hidden dimension tiling variance +1. **MatMul Lang (nl.matmul)** - High-level NKI API with K_TILE variation +2. **MatMul ISA (nisa.nc_matmul)** - Low-level ISA implementation with K_TILE variation +3. **RMSNorm (standard)** - Demonstrates natural batch invariance with atomic reductions +4. **RMSNorm (split reduction)** - Demonstrates hidden dimension tiling variance Each test compares: - **Invariant mode**: Fixed tile size (batch-invariant) - **Variant mode**: Adaptive tile size (batch-variant) - **Precision impact**: bfloat16 vs float32 +- **Quantization threshold effects**: When float32 errors fall below bfloat16's representable precision ## Results -### Test 1: MatMul - K_TILE Variance +### Test 1a: MatMul Lang (nl.matmul) - K_TILE Variance -**Configuration**: M=128, K=512, N=512 +**Configuration**: M=256, K=512, N=512 ``` bfloat16: @@ -116,10 +118,10 @@ bfloat16: float32: K_TILE=128 (invariant): 4 accumulations K_TILE=64 (variant): 8 accumulations - Max difference: 0.000050 + Max difference: 0.000046 Result: DIFFER ✓ -Precision impact: bfloat16 error is 157x larger than float32 +Precision impact: bfloat16 error is 170x larger than float32 ``` **Key Finding**: Different K_TILE sizes create different accumulation orders in the reduction: @@ -128,6 +130,41 @@ Precision impact: bfloat16 error is 157x larger than float32 Due to floating-point associativity: `(a + b) + c ≠ a + (b + c)` +### Test 1b: MatMul ISA (nisa.nc_matmul) - K_TILE Variance with Quantization Erasure + +**Configuration**: M=256, K=512, N=512 + +``` +bfloat16: + K_TILE=128 (invariant): 4 accumulations over K dimension + K_TILE=64 (variant): 8 accumulations over K dimension + Max difference: 0.000000 + Result: IDENTICAL ✓ + +float32: + K_TILE=128 (invariant): 4 accumulations + K_TILE=64 (variant): 8 accumulations + Max difference: 0.000061 + Result: DIFFER ✓ + +Precision impact: bfloat16 error is 0x smaller than float32 (error erased by quantization) +``` + +**Critical Discovery**: When float32 errors fall below bfloat16's quantization threshold (~0.008), quantization **erases** the differences rather than amplifying them: + +- **Lang kernel**: Float32 error (0.000046) crosses quantization threshold → bfloat16 amplifies to 0.007812 (170x) +- **ISA kernel**: Float32 error (0.000061) stays below threshold → bfloat16 quantizes both results identically (0.000000) + +**Why This Happens**: +1. Both kernels accumulate in float32 internally +2. Final output is quantized to bfloat16 +3. When float32 differences are sub-threshold: + - Both results round to the **same bfloat16 value** + - The error doesn't compound—it **vanishes** +4. ISA-level matmul has superior numerical stability, producing smaller float32 errors + +**Implication**: The ISA kernel's tighter numerical precision keeps K-tiling errors below bfloat16's representable range, making it more robust to batch size variations in reduced precision. + ### Test 2: RMSNorm (Standard) - Natural Batch Invariance **Configuration**: batch_size varies, hidden_dim=256 @@ -187,17 +224,31 @@ Precision impact: Variance only visible in bfloat16 for this test - ✅ **Creates variance**: MatMul K tiling - tiles reduction dimension with accumulation - ✅ **Creates variance**: RMSNorm split reduction - tiles hidden dimension with accumulation -### 📊 Precision Amplifies Variance +### 📊 Precision Effects: Amplification vs Erasure -| Operation | bfloat16 Error | float32 Error | Amplification | -|-----------|---------------|---------------|---------------| -| MatMul (K_TILE) | 0.007812 | 0.000050 | **157x** | -| RMSNorm Split (HIDDEN_TILE) | 0.007812 | ~0.000000 | Only visible in bfloat16 | +| Operation | float32 Error | bfloat16 Error | Amplification | Effect | +|-----------|---------------|----------------|---------------|--------| +| MatMul Lang (nl.matmul) | 0.000046 | 0.007812 | **170x** | Amplified | +| MatMul ISA (nisa.nc_matmul) | 0.000061 | 0.000000 | **0x** | Erased | +| RMSNorm Split (HIDDEN_TILE) | 0.000000 | 0.007812 | **21845x** | Amplified | -**Critical Insight**: Reduced precision (bfloat16) amplifies tiling variance dramatically: -- **Multiply-accumulate** (MatMul): Errors compound quickly, visible in both precisions -- **Pure addition** (RMSNorm sum): Errors compound slowly, only visible in bfloat16 -- **Implication**: bfloat16 sees more extreme batch variance +**Critical Insight**: Bfloat16 has **two distinct behaviors** depending on float32 error magnitude: + +1. **Above quantization threshold (~0.008)**: Errors are **amplified** + - Lang MatMul: 0.000046 → 0.007812 (170x amplification) + - RMSNorm: 0.000000 → 0.007812 (21845x amplification) + - Different accumulation orders produce distinguishable bfloat16 values + +2. **Below quantization threshold (~0.008)**: Errors are **erased** + - ISA MatMul: 0.000061 → 0.000000 (quantization erasure) + - Both K_TILE strategies round to identical bfloat16 values + - Variance becomes invisible in reduced precision + +**Why This Matters**: +- **Multiply-accumulate** (MatMul): Errors compound quickly, may cross threshold +- **Pure addition** (RMSNorm sum): Errors compound slowly, typically crosses threshold +- **ISA-level operations**: Superior numerical stability keeps errors sub-threshold +- **Implication**: Kernel implementation quality determines whether bfloat16 amplifies or erases tiling variance ### 🔬 Replicating Paper Findings with NKI @@ -235,6 +286,13 @@ K_TILE = 128 # Always - `nl.sum(entire_dimension)` is atomic - naturally invariant - Only manual tiling creates variance +4. **ISA-level numerical stability** + - Low-level ISA instructions (`nisa.nc_matmul`) exhibit superior numerical precision + - Tighter error bounds keep float32 differences below bfloat16's quantization threshold + - Quantization can erase tiling variance entirely in reduced precision + - Makes ISA kernels naturally more robust to batch size variations + - However, variance still exists in float32—testing in both precisions is essential + ## Implications for LLM Inference ### ✅ Benefits @@ -275,7 +333,21 @@ However, variance can still occur when: - Using reduced precision (bfloat16) with iterative accumulation - Adapting strategies based on input characteristics -**My findings directly replicate the Thinking Machines paper**: Batch variance stems from **dynamic tiling of reduction dimensions**, and the solution is **fixed tiling strategies**. NKI makes this easier by design, but developers must still be intentional about tile size choices, especially when using bfloat16 precision. +**Key findings that extend the Thinking Machines paper**: + +1. **Batch variance stems from dynamic tiling of reduction dimensions** (confirmed) +2. **Fixed tiling strategies solve the problem** (confirmed) +3. **NEW: Quantization threshold effect** - Bfloat16 doesn't always amplify errors: + - When float32 errors exceed ~0.008: Amplification occurs (170-21845x) + - When float32 errors stay below ~0.008: Quantization erases differences entirely + - ISA-level kernels with superior numerical stability can stay sub-threshold + - This makes some implementations naturally robust to batch variance in bfloat16 + +**Practical Implications**: +- High-quality kernel implementations (ISA-level) may hide batch variance in bfloat16 +- This can create false confidence—variance still exists in float32 +- Testing in float32 is essential to detect underlying numerical instability +- Don't rely on bfloat16 testing alone to validate batch invariance ## Running the Tests @@ -287,28 +359,61 @@ python test_batch_invariance.py **Expected Output:** ``` ================================================================================ -Testing MatMul batch invariance... +Testing MatMul Correctness... + Lang kernel (nl.matmul): ✓ Matches PyTorch reference + ISA kernel (nisa.nc_matmul): ✓ Matches PyTorch reference + +================================================================================ +Testing MatMul batch variance (Lang kernel)... + Testing with float32: + Max difference between K_TILE strategies: 0.000046 + Results differ Testing with bfloat16: Max difference between K_TILE strategies: 0.007812 Results differ + Precision impact: bfloat16 error is 170x larger than float32 + +================================================================================ +Testing MatMul batch variance (ISA kernel)... Testing with float32: - Max difference between K_TILE strategies: 0.000050 + Max difference: 0.000061 Results differ - Precision impact: bfloat16 error is 157x larger than float32 + Testing with bfloat16: + Max difference: 0.000000 + Results identical + Precision impact: bfloat16 error is 0x smaller than float32 + Note: Float32 error (0.000061) is below bfloat16 quantization threshold (~0.008) + Quantization erases the difference rather than amplifying it ================================================================================ Testing RMSNorm batch invariance... First 32 rows: batch=32 vs batch=128: MATCH ✓ ✓ RMSNorm is batch-invariant! + Each row computed independently, reduction is atomic ================================================================================ -Testing RMSNorm with Split Reduction... +Testing RMSNorm batch variance... + Max difference between HIDDEN_TILE strategies: 0.007812 + Results differ + ✗ Different HIDDEN_TILE sizes produce different results + +================================================================================ +Testing RMSNorm HIDDEN_TILE variance... Testing with bfloat16: Max difference between HIDDEN_TILE strategies: 0.007812 Results differ Testing with float32: Max difference between HIDDEN_TILE strategies: 0.000000 Results identical + Precision impact: bfloat16 error is 21845x larger than float32 + +================================================================================ +SUMMARY +MatMul & RMSNorm Batch Variance Results: +kernel float32_error bfloat16_error amplification +Lang (nl.matmul) 4.577637e-05 0.007812 170.666667 +ISA (nisa.nc_matmul) 6.103516e-05 0.000000 0.000000 +RMSNorm (HIDDEN_TILE) 3.576279e-07 0.007812 21845.333333 ``` ## Files From be7ff25bbd22b5fe03e55a97e4aa6266384f8c3f Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 29 Oct 2025 16:18:24 -0400 Subject: [PATCH 12/21] update readme --- contributed/batch_invariance/README.md | 64 ++++++++++++++------------ 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md index d7b7d28..4e2ffb7 100644 --- a/contributed/batch_invariance/README.md +++ b/contributed/batch_invariance/README.md @@ -150,20 +150,24 @@ float32: Precision impact: bfloat16 error is 0x smaller than float32 (error erased by quantization) ``` -**Critical Discovery**: When float32 errors fall below bfloat16's quantization threshold (~0.008), quantization **erases** the differences rather than amplifying them: +**Critical Discovery**: Identical tiling variance can be visible or invisible in bfloat16 depending on implementation—not because of error magnitude, but because of **quantization alignment**. -- **Lang kernel**: Float32 error (0.000046) crosses quantization threshold → bfloat16 amplifies to 0.007812 (170x) -- **ISA kernel**: Float32 error (0.000061) stays below threshold → bfloat16 quantizes both results identically (0.000000) +- **Lang kernel**: Float32 error (0.000046) → bfloat16 amplifies to 0.007812 (170x) +- **ISA kernel**: Float32 error (0.000061) → bfloat16 erases to 0.000000 -**Why This Happens**: -1. Both kernels accumulate in float32 internally -2. Final output is quantized to bfloat16 -3. When float32 differences are sub-threshold: - - Both results round to the **same bfloat16 value** - - The error doesn't compound—it **vanishes** -4. ISA-level matmul has superior numerical stability, producing smaller float32 errors +**The Quantization Alignment Effect**: -**Implication**: The ISA kernel's tighter numerical precision keeps K-tiling errors below bfloat16's representable range, making it more robust to batch size variations in reduced precision. +Both implementations produce small float32 errors (< 0.008), yet they behave completely differently in bfloat16. The difference isn't error magnitude—it's whether the two tiling strategies produce float32 values that fall into the **same or different bfloat16 quantization buckets**. + +1. **ISA kernel**: The two K_TILE strategies yield float32 outputs that, despite differing by 0.000061, happen to quantize to **identical bfloat16 values**. The variance exists in float32 but becomes invisible after quantization. + +2. **Lang kernel**: The two K_TILE strategies produce float32 outputs that fall into **different bfloat16 quantization buckets**. The 0.000046 float32 difference crosses a quantization boundary, manifesting as a full 0.007812 bfloat16 step. + +**Why This Matters**: + +ISA's superior numerical stability doesn't just produce smaller errors—it produces errors that **align better with bfloat16 quantization boundaries**, making variance less likely to manifest in reduced precision. However, the variance still exists in float32. + +**Implication**: ISA-level implementations may appear batch-invariant in bfloat16 while still exhibiting variance in float32. Testing in bfloat16 alone is insufficient—the underlying numerical instability remains and may compound in deeper networks. ### Test 2: RMSNorm (Standard) - Natural Batch Invariance @@ -232,23 +236,23 @@ Precision impact: Variance only visible in bfloat16 for this test | MatMul ISA (nisa.nc_matmul) | 0.000061 | 0.000000 | **0x** | Erased | | RMSNorm Split (HIDDEN_TILE) | 0.000000 | 0.007812 | **21845x** | Amplified | -**Critical Insight**: Bfloat16 has **two distinct behaviors** depending on float32 error magnitude: +**Critical Insight**: Bfloat16 has **two distinct behaviors** depending on quantization alignment: -1. **Above quantization threshold (~0.008)**: Errors are **amplified** +1. **Errors cross quantization boundaries**: Variance is **amplified** - Lang MatMul: 0.000046 → 0.007812 (170x amplification) - RMSNorm: 0.000000 → 0.007812 (21845x amplification) - - Different accumulation orders produce distinguishable bfloat16 values + - Different accumulation orders produce float32 values in different bfloat16 buckets -2. **Below quantization threshold (~0.008)**: Errors are **erased** +2. **Errors stay within quantization boundaries**: Variance is **erased** - ISA MatMul: 0.000061 → 0.000000 (quantization erasure) - - Both K_TILE strategies round to identical bfloat16 values + - Different accumulation orders produce float32 values in the same bfloat16 bucket - Variance becomes invisible in reduced precision **Why This Matters**: -- **Multiply-accumulate** (MatMul): Errors compound quickly, may cross threshold -- **Pure addition** (RMSNorm sum): Errors compound slowly, typically crosses threshold -- **ISA-level operations**: Superior numerical stability keeps errors sub-threshold -- **Implication**: Kernel implementation quality determines whether bfloat16 amplifies or erases tiling variance +- **Multiply-accumulate** (MatMul): Errors compound quickly, more likely to cross boundaries +- **Pure addition** (RMSNorm sum): Errors compound slowly, typically crosses boundaries +- **ISA-level operations**: Superior numerical stability produces errors that align better with quantization boundaries +- **Implication**: Kernel implementation quality determines whether bfloat16 amplifies or erases tiling variance through quantization alignment, not just error magnitude ### 🔬 Replicating Paper Findings with NKI @@ -286,12 +290,12 @@ K_TILE = 128 # Always - `nl.sum(entire_dimension)` is atomic - naturally invariant - Only manual tiling creates variance -4. **ISA-level numerical stability** +4. **ISA-level numerical stability and quantization alignment** - Low-level ISA instructions (`nisa.nc_matmul`) exhibit superior numerical precision - - Tighter error bounds keep float32 differences below bfloat16's quantization threshold - - Quantization can erase tiling variance entirely in reduced precision - - Makes ISA kernels naturally more robust to batch size variations - - However, variance still exists in float32—testing in both precisions is essential + - Produces errors that align better with bfloat16 quantization boundaries + - Different tiling strategies may quantize to identical bfloat16 values, erasing variance + - Makes ISA kernels appear more robust to batch size variations in reduced precision + - However, variance still exists in float32—comprehensive testing in both precisions is essential ## Implications for LLM Inference @@ -337,11 +341,11 @@ However, variance can still occur when: 1. **Batch variance stems from dynamic tiling of reduction dimensions** (confirmed) 2. **Fixed tiling strategies solve the problem** (confirmed) -3. **NEW: Quantization threshold effect** - Bfloat16 doesn't always amplify errors: - - When float32 errors exceed ~0.008: Amplification occurs (170-21845x) - - When float32 errors stay below ~0.008: Quantization erases differences entirely - - ISA-level kernels with superior numerical stability can stay sub-threshold - - This makes some implementations naturally robust to batch variance in bfloat16 +3. **NEW: Quantization alignment effect** - Bfloat16 doesn't always amplify errors: + - When float32 differences cross quantization boundaries: Amplification occurs (170-21845x) + - When float32 differences stay within quantization boundaries: Variance is erased entirely + - ISA-level kernels with superior numerical stability produce errors that align better with boundaries + - This makes some implementations appear robust to batch variance in bfloat16, while variance still exists in float32 **Practical Implications**: - High-quality kernel implementations (ISA-level) may hide batch variance in bfloat16 From 73419a7cc01772bd8949df7706eeedd095c0ab89 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:24:53 -0500 Subject: [PATCH 13/21] Enhance RMSNorm kernel with improved indexing Refactor RMSNorm kernel tto replace nl.arange with nl.mgrid --- .../kernels/rmsnorm_batch_invariant.py | 63 +++++++++---------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py index ab005d7..4d15081 100644 --- a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -8,7 +8,7 @@ import math import neuronxcc.nki as nki import neuronxcc.nki.language as nl - +import neuronxcc.nki.isa as nisa @nki.jit def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): @@ -36,69 +36,62 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): HIDDEN_TILE = 256 # Fewer chunks (e.g., 2 for hidden_dim=512) else: HIDDEN_TILE = 128 # More chunks (e.g., 4 for hidden_dim=512) + + # Create indices for chunked tile + ix, iy = nl.mgrid[0:BATCH_TILE, 0:HIDDEN_TILE] - ix = nl.arange(BATCH_TILE)[:, None] - iw = nl.arange(1)[:, None] + # Create indices for full tile + ix_full, iy_full = nl.mgrid[0:BATCH_TILE, 0:hidden_dim] - # Process batch in tiles + # Load weight once + iw, iy_g = nl.mgrid[0:1, 0:hidden_dim] + g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_g]) + + # Loop over batch dimension for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): - # SPLIT REDUCTION: Accumulate partial sums across hidden dimension chunks - # Use PSUM for accumulation (always float32 internally) partial_square_sum = nl.zeros((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.psum) # Iterate over hidden dimension in chunks - num_hidden_tiles = math.ceil(hidden_dim / HIDDEN_TILE) - for h in nl.affine_range(num_hidden_tiles): - h_start = h * HIDDEN_TILE - - # Create indices for this hidden chunk (always use full HIDDEN_TILE, mask later) - iy = nl.arange(HIDDEN_TILE)[None, :] - - # Create mask for valid hidden indices - valid_mask = ((i * BATCH_TILE + ix < num_rows) & - (h * HIDDEN_TILE + iy < hidden_dim)) - - # Load a CHUNK of the hidden dimension with proper indexing + for h in nl.affine_range(math.ceil(hidden_dim / HIDDEN_TILE)): + # Load chunk with mask a_chunk = nl.load( a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], - mask=valid_mask + mask=(i * BATCH_TILE + ix < num_rows) & (h * HIDDEN_TILE + iy < hidden_dim) ) # Square this chunk - in_square_chunk = nl.square(a_chunk) + chunk_square = nl.square(a_chunk) # Reduce this chunk (sum along hidden dimension) - # Mask ensures we only sum valid elements - chunk_sum = nl.sum(in_square_chunk, axis=[1], keepdims=True, - mask=valid_mask) + chunk_sum = nl.sum(chunk_square, axis=[1], keepdims=True) # ACCUMULATE: This is where variance enters! # Different HIDDEN_TILE sizes mean different number of additions partial_square_sum += chunk_sum # Compute mean and RMS - mean = partial_square_sum / hidden_dim + mean = partial_square_sum * (1.0 / hidden_dim) rms_reciprocal = nl.rsqrt(mean) - # Now load full row for normalization - iy_full = nl.arange(hidden_dim)[None, :] + # Load full row for normalization with mask a_tile = nl.load( - a_tensor[i * BATCH_TILE + ix, iy_full], - mask=(i * BATCH_TILE + ix < num_rows) + a_tensor[i * BATCH_TILE + ix_full, iy_full], + mask=(i * BATCH_TILE + ix_full < num_rows) ) # Normalize by RMS out_tile = nl.multiply(a_tile, rms_reciprocal) # Apply weight - g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_full]) g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) - out_tile = nl.multiply(out_tile, g_bcast, - mask=(i * BATCH_TILE + ix < num_rows)) + out_tile = nl.multiply(out_tile, g_bcast, mask=(i * BATCH_TILE + ix_full < num_rows)) - # Store result - nl.store(out_tensor[i * BATCH_TILE + ix, iy_full], value=out_tile, - mask=(i * BATCH_TILE + ix < num_rows)) - + # Store result with mask + nl.store( + out_tensor[i * BATCH_TILE + ix_full, iy_full], + value=out_tile, + mask=(i * BATCH_TILE + ix_full < num_rows) + ) + return out_tensor From 3843cac999029659319ce37eeeef44e3eb2dbf01 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:33:55 -0500 Subject: [PATCH 14/21] Optimize memory operations using nisa.dma_copy Replaced direct load/store operations with nisa.dma_copy for better performance. --- .../kernels/rmsnorm_batch_invariant.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py index 4d15081..85cb706 100644 --- a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -43,9 +43,13 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): # Create indices for full tile ix_full, iy_full = nl.mgrid[0:BATCH_TILE, 0:hidden_dim] - # Load weight once + # Load weight once using nisa.dma_copy iw, iy_g = nl.mgrid[0:1, 0:hidden_dim] - g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_g]) + g_tile = nl.ndarray((1, hidden_dim), dtype=g_tensor.dtype, buffer=nl.sbuf) + nisa.dma_copy( + src=g_tensor.reshape((1, hidden_dim))[iw, iy_g], + dst=g_tile[iw, iy_g] + ) # Loop over batch dimension for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): @@ -54,9 +58,13 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): # Iterate over hidden dimension in chunks for h in nl.affine_range(math.ceil(hidden_dim / HIDDEN_TILE)): - # Load chunk with mask - a_chunk = nl.load( - a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], + # Allocate buffer for chunk + a_chunk = nl.ndarray((BATCH_TILE, HIDDEN_TILE), dtype=a_tensor.dtype, buffer=nl.sbuf) + + # Load chunk with mask using nisa.dma_copy + nisa.dma_copy( + src=a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], + dst=a_chunk[ix, iy], mask=(i * BATCH_TILE + ix < num_rows) & (h * HIDDEN_TILE + iy < hidden_dim) ) @@ -74,9 +82,13 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): mean = partial_square_sum * (1.0 / hidden_dim) rms_reciprocal = nl.rsqrt(mean) - # Load full row for normalization with mask - a_tile = nl.load( - a_tensor[i * BATCH_TILE + ix_full, iy_full], + # Allocate buffer for full tile + a_tile = nl.ndarray((BATCH_TILE, hidden_dim), dtype=a_tensor.dtype, buffer=nl.sbuf) + + # Load full row for normalization with mask using nisa.dma_copy + nisa.dma_copy( + src=a_tensor[i * BATCH_TILE + ix_full, iy_full], + dst=a_tile[ix_full, iy_full], mask=(i * BATCH_TILE + ix_full < num_rows) ) @@ -87,10 +99,10 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) out_tile = nl.multiply(out_tile, g_bcast, mask=(i * BATCH_TILE + ix_full < num_rows)) - # Store result with mask - nl.store( - out_tensor[i * BATCH_TILE + ix_full, iy_full], - value=out_tile, + # Store result with mask using nisa.dma_copy + nisa.dma_copy( + src=out_tile[ix_full, iy_full], + dst=out_tensor[i * BATCH_TILE + ix_full, iy_full], mask=(i * BATCH_TILE + ix_full < num_rows) ) From 34142ed8cf5c652d9a1ca64b5cea46640bc4509c Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:36:33 -0500 Subject: [PATCH 15/21] Optimize matmul with DMA copy for tile loading Using DMA copy for improved performance. --- .../kernels/matmul_batch_invariant.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/contributed/batch_invariance/kernels/matmul_batch_invariant.py b/contributed/batch_invariance/kernels/matmul_batch_invariant.py index 7be3727..d957dae 100644 --- a/contributed/batch_invariance/kernels/matmul_batch_invariant.py +++ b/contributed/batch_invariance/kernels/matmul_batch_invariant.py @@ -38,22 +38,32 @@ def nki_matmul_kernel_isa(a, b, batch_invariant=True): # Reduction over K for k in nl.affine_range(K // K_TILE): - # Load a: [K_TILE, M_TILE] + # Allocate and load a: [K_TILE, M_TILE] i_a_p, i_a_f = nl.mgrid[0:K_TILE, 0:M_TILE] - a_tile = nl.load(a[k*K_TILE + i_a_p, m*M_TILE + i_a_f]) + a_tile = nl.ndarray((K_TILE, M_TILE), dtype=a.dtype, buffer=nl.sbuf) + nisa.dma_copy( + src=a[k*K_TILE + i_a_p, m*M_TILE + i_a_f], + dst=a_tile[i_a_p, i_a_f] + ) - # Load b: [K_TILE, N] + # Allocate and load b: [K_TILE, N] i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] - b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + b_tile = nl.ndarray((K_TILE, N), dtype=b.dtype, buffer=nl.sbuf) + nisa.dma_copy( + src=b[k*K_TILE + i_b_p, i_b_f], + dst=b_tile[i_b_p, i_b_f] + ) # Matmul - - print(a_tile.shape, b_tile.shape) c_psum += nisa.nc_matmul(a_tile, b_tile) + # Store this M chunk i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] c_sbuf = nl.copy(c_psum, dtype=result.dtype) - nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + nisa.dma_copy( + src=c_sbuf[i_out_p, i_out_f], + dst=result[m*M_TILE + i_out_p, i_out_f] + ) return result From 31299db760ddf873e6a361f4918676472c328a27 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Tue, 4 Nov 2025 16:21:48 -0500 Subject: [PATCH 16/21] Refactor RMSNorm tests for batch invariance and variance --- .../batch_invariance/test_batch_invariance.py | 234 ++++++++---------- 1 file changed, 102 insertions(+), 132 deletions(-) diff --git a/contributed/batch_invariance/test_batch_invariance.py b/contributed/batch_invariance/test_batch_invariance.py index 659b491..9223622 100644 --- a/contributed/batch_invariance/test_batch_invariance.py +++ b/contributed/batch_invariance/test_batch_invariance.py @@ -6,7 +6,7 @@ import time import torch_neuronx import numpy as np -from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel +from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel_lang, nki_rmsnorm_kernel_isa from kernels.matmul_batch_invariant import nki_matmul_kernel_isa, nki_matmul_kernel_lang # Prove that the kernels match pytorch and are functionally correct @@ -254,108 +254,93 @@ def test_matmul_lang(): "amplification": ratio } -def test_rmsnorm_invariant(): + + + +def test_rmsnorm_lang(): """ - RMSNorm demonstrates batch INVARIANCE with consistent tiling. + RMSNorm Lang kernel HIDDEN_TILE variance with precision effects. + + Uses nl.load, nl.store, nl.sum for data movement and reduction. + Different HIDDEN_TILE sizes create different reduction orders. - When using the same batch_invariant=True setting, results should be - identical regardless of batch size because each row is computed independently. + Expected: Shows variance in both float32 and bfloat16 Returns: - dict: Test results showing invariance + dict: Test results with float32 and bfloat16 errors """ - print("Testing RMSNorm batch invariance...") - + print("Testing RMSNorm batch variance (Lang kernel)...") device = 'xla' - hidden_dim = 256 - - # Create a large input with many rows + hidden_dim = 512 large_batch = 128 - a_large = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.bfloat16) - g = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16) - - # Test the SAME 32 rows in different batch contexts - a_small = a_large[:32, :] - - # Process as small batch (32 rows) - result_small = nki_rmsnorm_kernel(a_small, g, batch_invariant=True) - - # Process as part of large batch (128 rows) - result_large = nki_rmsnorm_kernel(a_large, g, batch_invariant=True) - - # Compare the SAME rows - diff = torch.max(torch.abs(result_small - result_large[:32])).item() - match = diff < 1e-6 - - print(f" First 32 rows: batch=32 vs batch=128: {'MATCH ✓' if match else 'DIFFER ✗'}") - print(f" Max difference: {diff:.6f}") - - if match: - print(f" ✓ RMSNorm is batch-invariant!") - print(f" Each row computed independently, reduction is atomic") - print(f" Tile size only affects parallelism, not computation order") - - return { - "test": "RMSNorm Invariant", - "max_difference": diff, - "is_invariant": match - } - -def test_rmsnorm_variant(): - """ - RMSNorm demonstrates batch VARIANCE with different tiling strategies. + small_batch = 32 - When using different batch_invariant settings (True vs False), results may - differ due to different HIDDEN_TILE sizes affecting reduction chunking. + print(f" hidden_dim={hidden_dim}") + print(f" batch_invariant=True: HIDDEN_TILE=256 (2 chunks, 1 accumulation)") + print(f" batch_invariant=False: HIDDEN_TILE=128 (4 chunks, 3 accumulations)") + print() - Returns: - dict: Test results showing variance - """ - print("Testing RMSNorm batch variance...") + # Create data ONCE in float32 + print(" Creating data in float32...") + a_large_f32 = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.float32) + g_f32 = torch.ones(hidden_dim, device=device, dtype=torch.float32) - device = 'xla' - hidden_dim = 256 + # Test with float32 FIRST + print(" Testing with float32:") + a_small_f32 = a_large_f32[:small_batch, :] - # Create a large input with many rows - large_batch = 128 - a_large = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.bfloat16) - g = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16) + result_small_f32 = nki_rmsnorm_kernel_lang(a_small_f32, g_f32, batch_invariant=True) + result_large_f32 = nki_rmsnorm_kernel_lang(a_large_f32, g_f32, batch_invariant=False) - # Test the SAME 32 rows in different batch contexts - a_small = a_large[:32, :] + diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item() + print(f" Max difference between HIDDEN_TILE strategies: {diff_f32:.6f}") + print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") + print() - # Process as small batch (32 rows) with batch_invariant=True - result_small = nki_rmsnorm_kernel(a_small, g, batch_invariant=True) + # Cast to bfloat16 + print(" Testing with bfloat16:") + a_large_bf16 = a_large_f32.to(torch.bfloat16) + g_bf16 = g_f32.to(torch.bfloat16) + a_small_bf16 = a_large_bf16[:small_batch, :] - # Process as part of large batch (128 rows) with batch_invariant=False - result_large = nki_rmsnorm_kernel(a_large, g, batch_invariant=False) + result_small_bf16 = nki_rmsnorm_kernel_lang(a_small_bf16, g_bf16, batch_invariant=True) + result_large_bf16 = nki_rmsnorm_kernel_lang(a_large_bf16, g_bf16, batch_invariant=False) - diff_bf16 = torch.max(torch.abs(result_small - result_large[:32])).item() - print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") - print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item() + print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") + print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + print() - if diff_bf16 > 1e-6: - print(f" ✗ Different HIDDEN_TILE sizes produce different results") - print(f" This demonstrates tiling strategy affects reduction order") + if diff_f32 > 0: + ratio = diff_bf16 / diff_f32 + print(f" Precision impact: bfloat16 error is {ratio:.2f}x {'larger' if diff_bf16 > diff_f32 else 'smaller'} than float32") + print(f" Lang kernel shows variance due to different reduction chunking") + else: + ratio = 0.0 + print(f" Precision impact: N/A (no float32 difference detected)") return { - "test": "RMSNorm Variant", - "max_difference": diff_bf16, - "is_invariant": diff_bf16 < 1e-6 + "kernel": "RMSNorm Lang (nl.sum)", + "float32_error": diff_f32, + "bfloat16_error": diff_bf16, + "amplification": ratio } -def test_rmsnorm_accuracy_diff(): +def test_rmsnorm_isa(): """ - RMSNorm HIDDEN_TILE variance with precision effects. + RMSNorm ISA kernel demonstrates batch INVARIANCE. - Tests how different HIDDEN_TILE sizes affect reduction chunking and - whether precision amplifies these differences. + Uses nisa.dma_copy and nisa.tensor_reduce with skip_middle_end_transformations. + Despite different HIDDEN_TILE sizes, ISA produces identical results. + + Expected: No variance in either float32 or bfloat16 + Reason: ISA-level operations are deterministic regardless of tiling strategy Returns: - dict: Test results with float32 and bfloat16 errors + dict: Test results with float32 and bfloat16 errors (should be 0.0) """ - print("Testing RMSNorm HIDDEN_TILE variance...") + print("Testing RMSNorm batch INVARIANCE (ISA kernel)...") device = 'xla' hidden_dim = 512 large_batch = 128 @@ -364,62 +349,60 @@ def test_rmsnorm_accuracy_diff(): print(f" hidden_dim={hidden_dim}") print(f" batch_invariant=True: HIDDEN_TILE=256 (2 chunks, 1 accumulation)") print(f" batch_invariant=False: HIDDEN_TILE=128 (4 chunks, 3 accumulations)") + print(f" Note: ISA kernel uses @skip_middle_end_transformations") print() - # Test with bfloat16 - print(" Testing with bfloat16:") - a_large_bf16 = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.bfloat16) - g_bf16 = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16) - - # Test the SAME 32 rows in different batch contexts - a_small_bf16 = a_large_bf16[:small_batch, :] - - # Process as small batch (32 rows) - result_small_bf16 = nki_rmsnorm_kernel(a_small_bf16, g_bf16, batch_invariant=True) # HIDDEN_TILE=256 - - # Process as part of large batch (128 rows) - result_large_bf16 = nki_rmsnorm_kernel(a_large_bf16, g_bf16, batch_invariant=False) # HIDDEN_TILE=128 - - # Compare the SAME rows - diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item() - print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") - print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") - print() - - # Test with float32 - print(" Testing with float32:") + # Create data ONCE in float32 + print(" Creating data in float32...") a_large_f32 = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.float32) g_f32 = torch.ones(hidden_dim, device=device, dtype=torch.float32) - # Test the SAME 32 rows in different batch contexts + # Test with float32 FIRST + print(" Testing with float32:") a_small_f32 = a_large_f32[:small_batch, :] - # Process as small batch (32 rows) - result_small_f32 = nki_rmsnorm_kernel(a_small_f32, g_f32, batch_invariant=True) # HIDDEN_TILE=256 + result_small_f32 = nki_rmsnorm_kernel_isa(a_small_f32, g_f32, batch_invariant=True) + result_large_f32 = nki_rmsnorm_kernel_isa(a_large_f32, g_f32, batch_invariant=False) - # Process as part of large batch (128 rows) - result_large_f32 = nki_rmsnorm_kernel(a_large_f32, g_f32, batch_invariant=False) # HIDDEN_TILE=128 - - # Compare the SAME rows diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item() print(f" Max difference between HIDDEN_TILE strategies: {diff_f32:.6f}") print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}") print() - if diff_f32 > 0: - ratio = diff_bf16 / diff_f32 + # Cast to bfloat16 + print(" Testing with bfloat16:") + a_large_bf16 = a_large_f32.to(torch.bfloat16) + g_bf16 = g_f32.to(torch.bfloat16) + a_small_bf16 = a_large_bf16[:small_batch, :] + + result_small_bf16 = nki_rmsnorm_kernel_isa(a_small_bf16, g_bf16, batch_invariant=True) + result_large_bf16 = nki_rmsnorm_kernel_isa(a_large_bf16, g_bf16, batch_invariant=False) + + diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item() + print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}") + print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}") + print() + + if diff_f32 == 0.0 and diff_bf16 == 0.0: + print(f" ✓ ISA kernel is BATCH INVARIANT!") + print(f" @skip_middle_end_transformations ensures deterministic reduction") + print(f" regardless of HIDDEN_TILE size") + ratio = 0.0 + elif diff_f32 > 0: + ratio = diff_bf16 / diff_f32 if diff_f32 > 0 else 0.0 print(f" Precision impact: bfloat16 error is {ratio:.2f}x {'larger' if diff_bf16 > diff_f32 else 'smaller'} than float32") else: ratio = 0.0 - print(f" Precision impact: N/A (no float32 difference detected)") + print(f" Precision impact: N/A") return { - "kernel": "RMSNorm (HIDDEN_TILE)", + "kernel": "RMSNorm ISA (nisa.tensor_reduce)", "float32_error": diff_f32, "bfloat16_error": diff_bf16, "amplification": ratio } + if __name__ == "__main__": import pandas as pd @@ -442,35 +425,22 @@ def test_rmsnorm_accuracy_diff(): print("=" * 80) - # Test RMSNorm invariance - print("=" * 80) - print("\nRunning RMSNorm batch invariance test...") - rmsnorm_invariant = test_rmsnorm_invariant() - - print("=" * 80) - - # Test RMSNorm variance - print("\nRunning RMSNorm batch variance test...") - rmsnorm_variant = test_rmsnorm_variant() + # Test RMSNorm Lang kernel + print("\nRunning RMSNorm Lang kernel test...") + rmsnorm_lang_results = test_rmsnorm_lang() print("=" * 80) - # Test RMSNorm HIDDEN_TILE precision effects - print("\nRunning RMSNorm HIDDEN_TILE variance test...") - rmsnorm_results = test_rmsnorm_accuracy_diff() + # Test RMSNorm ISA kernel + print("\nRunning RMSNorm ISA kernel test...") + rmsnorm_isa_results = test_rmsnorm_isa() print("\n" + "=" * 80) print("SUMMARY") print("=" * 80) - # Create results dataframes - print("\nMatMul & RMSNorm Batch Variance Results:") - variance_df = pd.DataFrame([lang_results, isa_results, rmsnorm_results]) + # Create results dataframe + print("\nBatch Variance Results:") + variance_df = pd.DataFrame([lang_results, isa_results, rmsnorm_lang_results, rmsnorm_isa_results]) print(variance_df.to_string(index=False)) print() - - print("\nRMSNorm Invariance vs Variance:") - invariance_df = pd.DataFrame([rmsnorm_invariant, rmsnorm_variant]) - print(invariance_df.to_string(index=False)) - print() - From 4608fe82ec32dd7d0ae57a38cbea997561417d05 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Tue, 4 Nov 2025 16:23:00 -0500 Subject: [PATCH 17/21] Add isa and lang versions to demonstrate variance --- .../kernels/rmsnorm_batch_invariant.py | 102 +++++++++++++++++- 1 file changed, 99 insertions(+), 3 deletions(-) diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py index 85cb706..f981514 100644 --- a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -10,8 +10,98 @@ import neuronxcc.nki.language as nl import neuronxcc.nki.isa as nisa + @nki.jit -def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): +def nki_rmsnorm_kernel_lang(a_tensor, g_tensor, batch_invariant=True): + """ + RMSNorm with split reduction along hidden dimension + + batch_invariant=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) + batch_invariant=False: HIDDEN_TILE=128 (more chunks, more accumulations) + + This demonstrates REAL batch variance because different tile sizes + change the order of floating-point additions during reduction. + """ + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) + + assert a_tensor.shape[1] == g_tensor.shape[0] + + num_rows = a_tensor.shape[0] + hidden_dim = a_tensor.shape[1] + BATCH_TILE = 128 + + # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) + # Different sizes = different number of accumulations = variance! + if batch_invariant: + HIDDEN_TILE = 256 # Fewer chunks (e.g., 2 for hidden_dim=512) + else: + HIDDEN_TILE = 128 # More chunks (e.g., 4 for hidden_dim=512) + + ix = nl.arange(BATCH_TILE)[:, None] + iw = nl.arange(1)[:, None] + + # Process batch in tiles + for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): + # SPLIT REDUCTION: Accumulate partial sums across hidden dimension chunks + # Use PSUM for accumulation (always float32 internally) + partial_square_sum = nl.zeros((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.psum) + + # Iterate over hidden dimension in chunks + num_hidden_tiles = math.ceil(hidden_dim / HIDDEN_TILE) + for h in nl.affine_range(num_hidden_tiles): + h_start = h * HIDDEN_TILE + + # Create indices for this hidden chunk (always use full HIDDEN_TILE, mask later) + iy = nl.arange(HIDDEN_TILE)[None, :] + + # Create mask for valid hidden indices + valid_mask = ((i * BATCH_TILE + ix < num_rows) & + (h * HIDDEN_TILE + iy < hidden_dim)) + + # Load a CHUNK of the hidden dimension with proper indexing + a_chunk = nl.load(a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], + mask=valid_mask) + + # Square this chunk + in_square_chunk = nl.square(a_chunk) + + # Reduce this chunk (sum along hidden dimension) + # Mask ensures we only sum valid elements + chunk_sum = nl.sum(in_square_chunk, axis=[1], keepdims=True, mask=valid_mask) + + # ACCUMULATE: This is where variance enters! + # Different HIDDEN_TILE sizes mean different number of additions + partial_square_sum += chunk_sum + + # Compute mean and RMS + mean = partial_square_sum / hidden_dim + rms_reciprocal = nl.rsqrt(mean) + + # Now load full row for normalization + iy_full = nl.arange(hidden_dim)[None, :] + a_tile = nl.load(a_tensor[i * BATCH_TILE + ix, iy_full], + mask=(i * BATCH_TILE + ix < num_rows)) + + # Normalize by RMS + out_tile = nl.multiply(a_tile, rms_reciprocal) + + # Apply weight + g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_full]) + g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) + out_tile = nl.multiply(out_tile, g_bcast, + mask=(i * BATCH_TILE + ix < num_rows)) + + # Store result + nl.store(out_tensor[i * BATCH_TILE + ix, iy_full], value=out_tile, + mask=(i * BATCH_TILE + ix < num_rows)) + + return out_tensor + + +@nki.compiler.skip_middle_end_transformations +@nki.jit +def nki_rmsnorm_kernel_isa(a_tensor, g_tensor, batch_invariant=True): """ RMSNorm with split reduction along hidden dimension @@ -71,8 +161,14 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, batch_invariant=True): # Square this chunk chunk_square = nl.square(a_chunk) - # Reduce this chunk (sum along hidden dimension) - chunk_sum = nl.sum(chunk_square, axis=[1], keepdims=True) + # Reduce this chunk (sum along hidden dimension) using nisa.tensor_reduce + chunk_sum = nisa.tensor_reduce( + nl.add, + chunk_square[ix, iy], + axis=[1], + keepdims=True, + dtype=nl.float32 + ) # ACCUMULATE: This is where variance enters! # Different HIDDEN_TILE sizes mean different number of additions From 89a1982689c16b37490e63be74a88582ed25eaa4 Mon Sep 17 00:00:00 2001 From: Jlonge4 Date: Tue, 4 Nov 2025 20:39:17 -0500 Subject: [PATCH 18/21] streamline readme --- contributed/batch_invariance/README.md | 428 ++----------------------- 1 file changed, 22 insertions(+), 406 deletions(-) diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md index 4e2ffb7..57979e9 100644 --- a/contributed/batch_invariance/README.md +++ b/contributed/batch_invariance/README.md @@ -1,431 +1,47 @@ -# NKI Batch Invariance Test +# NKI Batch Invariance: ISA vs Lang Kernels -Demonstrating batch invariance principles in NKI (Neuron Kernel Interface), replicating findings from [Thinking Machines' "Defeating Nondeterminism in LLM Inference"](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/). - -## What is Batch Invariance? - -**Batch invariance** means that computing the same element in different batch sizes produces **identical numerical results**. The paper demonstrates that CUDA/PyTorch matrix multiplication is **NOT batch-invariant** due to dynamic optimization strategies that change based on batch size. - -## When Does Batch Variance Occur? - -Batch variance occurs when **ALL THREE conditions are met**: - -1. **Tiling the reduction dimension** (not parallelizable dimensions) - - MatMul: Tiling K (contraction dimension) ✓ - - RMSNorm: Tiling hidden dimension in split reduction ✓ - -2. **Iterative accumulation across tiles** (not atomic reductions) - - `c_psum += matmul(a_tile, b_tile)` ✓ Creates variance - - `nl.sum(entire_row)` ✗ Atomic, no variance - -3. **Dynamic tile size based on input characteristics** - - CUDA SplitK: Adapts K strategy based on batch size ✓ - - NKI (fixed): `K_TILE = 128` always ✗ - - NKI (variant): `K_TILE = 64 if K <= 512 else 128` ✓ - -```mermaid -flowchart TD - Start[Input Tensor: batch_size x hidden_dim 1024] --> CheckBatch{What is batch_size?} - - CheckBatch -->|batch < 64| SmallBatch[Small Batch Strategy] - CheckBatch -->|64 ≤ batch < 128| MediumBatch[Medium Batch Strategy] - CheckBatch -->|batch ≥ 128| LargeBatch[Large Batch Strategy] - - SmallBatch --> TileSmall[TILE_SIZE = 64] - MediumBatch --> TileMedium[TILE_SIZE = 128] - LargeBatch --> TileLarge[TILE_SIZE = 256] - - TileSmall --> ChunkSmall[Split hidden_dim into 16 chunks] - TileMedium --> ChunkMedium[Split hidden_dim into 8 chunks] - TileLarge --> ChunkLarge[Split hidden_dim into 4 chunks] - - ChunkSmall --> ReduceSmall[Reduce each chunk:
sum elements 0:64
sum elements 64:128
... 16 partial sums] - ChunkMedium --> ReduceMedium[Reduce each chunk:
sum elements 0:128
sum elements 128:256
... 8 partial sums] - ChunkLarge --> ReduceLarge[Reduce each chunk:
sum elements 0:256
sum elements 256:512
... 4 partial sums] - - ReduceSmall --> AccumSmall[Accumulate 16 partials:
p1 + p2 = t1
t1 + p3 = t2
... 15 additions] - ReduceMedium --> AccumMedium[Accumulate 8 partials:
p1 + p2 = t1
t1 + p3 = t2
... 7 additions] - ReduceLarge --> AccumLarge[Accumulate 4 partials:
p1 + p2 = t1
t1 + p3 = t2
... 3 additions] - - AccumSmall --> ResultSmall[result_small
15 rounding errors] - AccumMedium --> ResultMedium[result_medium
7 rounding errors] - AccumLarge --> ResultLarge[result_large
3 rounding errors] - - ResultSmall --> Compare{Compare Results} - ResultMedium --> Compare - ResultLarge --> Compare - - Compare --> NotEqual[❌ result_small ≠ result_medium ≠ result_large
Different accumulation orders
Different floating-point rounding
NON-DETERMINISTIC] - - NotEqual --> Problem[🔥 PROBLEM: Same input data,
different batch sizes yield
different numerical results!] - - Problem --> Solution[✅ SOLUTION: Hardcode TILE_SIZE] - - Solution --> FixedTile[TILE_SIZE = 128 always] - FixedTile --> FixedChunks[Always 8 chunks
Always 7 accumulations
for ALL batch sizes] - FixedChunks --> Deterministic[✅ DETERMINISTIC RESULTS
batch=32: 8 chunks, 7 adds
batch=96: 8 chunks, 7 adds
batch=256: 8 chunks, 7 adds] - - style Start fill:#e3f2fd - style CheckBatch fill:#fff3e0 - style SmallBatch fill:#ffebee - style MediumBatch fill:#e8eaf6 - style LargeBatch fill:#f3e5f5 - style TileSmall fill:#ef5350,color:#fff - style TileMedium fill:#42a5f5,color:#fff - style TileLarge fill:#ab47bc,color:#fff - style NotEqual fill:#ffcdd2 - style Problem fill:#ff5252,color:#fff - style Solution fill:#81c784 - style Deterministic fill:#66bb6a,color:#fff - style FixedTile fill:#4caf50,color:#fff -``` -## Test Environment - -- **Instance**: `inf2.xlarge` (AWS Trainium) -- **AMI ID**: `ami-0ec4ab14b1c5a10f2` -- **AMI Name**: `Deep Learning AMI Neuron (Ubuntu 22.04) 20250919` -- **Compiler**: `neuronxcc-2.21.18209.0` -- **Framework**: NKI (Neuron Kernel Interface) - -## Test Suite Overview - -We test four kernel implementations: - -1. **MatMul Lang (nl.matmul)** - High-level NKI API with K_TILE variation -2. **MatMul ISA (nisa.nc_matmul)** - Low-level ISA implementation with K_TILE variation -3. **RMSNorm (standard)** - Demonstrates natural batch invariance with atomic reductions -4. **RMSNorm (split reduction)** - Demonstrates hidden dimension tiling variance - -Each test compares: -- **Invariant mode**: Fixed tile size (batch-invariant) -- **Variant mode**: Adaptive tile size (batch-variant) -- **Precision impact**: bfloat16 vs float32 -- **Quantization threshold effects**: When float32 errors fall below bfloat16's representable precision - -## Results - -### Test 1a: MatMul Lang (nl.matmul) - K_TILE Variance - -**Configuration**: M=256, K=512, N=512 - -``` -bfloat16: - K_TILE=128 (invariant): 4 accumulations over K dimension - K_TILE=64 (variant): 8 accumulations over K dimension - Max difference: 0.007812 - Result: DIFFER ✓ - -float32: - K_TILE=128 (invariant): 4 accumulations - K_TILE=64 (variant): 8 accumulations - Max difference: 0.000046 - Result: DIFFER ✓ - -Precision impact: bfloat16 error is 170x larger than float32 -``` - -**Key Finding**: Different K_TILE sizes create different accumulation orders in the reduction: -- K_TILE=128: `((chunk0 + chunk1) + chunk2) + chunk3` (4 tiles) -- K_TILE=64: `(((((((ch0 + ch1) + ch2) + ch3) + ch4) + ch5) + ch6) + ch7)` (8 tiles) - -Due to floating-point associativity: `(a + b) + c ≠ a + (b + c)` - -### Test 1b: MatMul ISA (nisa.nc_matmul) - K_TILE Variance with Quantization Erasure - -**Configuration**: M=256, K=512, N=512 - -``` -bfloat16: - K_TILE=128 (invariant): 4 accumulations over K dimension - K_TILE=64 (variant): 8 accumulations over K dimension - Max difference: 0.000000 - Result: IDENTICAL ✓ - -float32: - K_TILE=128 (invariant): 4 accumulations - K_TILE=64 (variant): 8 accumulations - Max difference: 0.000061 - Result: DIFFER ✓ - -Precision impact: bfloat16 error is 0x smaller than float32 (error erased by quantization) -``` - -**Critical Discovery**: Identical tiling variance can be visible or invisible in bfloat16 depending on implementation—not because of error magnitude, but because of **quantization alignment**. - -- **Lang kernel**: Float32 error (0.000046) → bfloat16 amplifies to 0.007812 (170x) -- **ISA kernel**: Float32 error (0.000061) → bfloat16 erases to 0.000000 - -**The Quantization Alignment Effect**: - -Both implementations produce small float32 errors (< 0.008), yet they behave completely differently in bfloat16. The difference isn't error magnitude—it's whether the two tiling strategies produce float32 values that fall into the **same or different bfloat16 quantization buckets**. - -1. **ISA kernel**: The two K_TILE strategies yield float32 outputs that, despite differing by 0.000061, happen to quantize to **identical bfloat16 values**. The variance exists in float32 but becomes invisible after quantization. - -2. **Lang kernel**: The two K_TILE strategies produce float32 outputs that fall into **different bfloat16 quantization buckets**. The 0.000046 float32 difference crosses a quantization boundary, manifesting as a full 0.007812 bfloat16 step. - -**Why This Matters**: - -ISA's superior numerical stability doesn't just produce smaller errors—it produces errors that **align better with bfloat16 quantization boundaries**, making variance less likely to manifest in reduced precision. However, the variance still exists in float32. - -**Implication**: ISA-level implementations may appear batch-invariant in bfloat16 while still exhibiting variance in float32. Testing in bfloat16 alone is insufficient—the underlying numerical instability remains and may compound in deeper networks. - -### Test 2: RMSNorm (Standard) - Natural Batch Invariance - -**Configuration**: batch_size varies, hidden_dim=256 - -``` -Same 32 rows computed in: - - batch=32 context - - batch=128 context - -Result: MATCH ✓ (identical) -Max difference: 0.0 -``` - -**RMSNorm remains batch-invariant UNTIL you:** -- Tile the **hidden dimension** (the reduction axis) instead of the batch dimension -- Make that tile size **dynamic** based on input characteristics -- Use **iterative accumulation** across hidden dimension chunks (see Test 3 for this scenario) - -### Test 3: RMSNorm (Split Reduction) - Hidden Dimension Tiling Variance - -**Configuration**: batch_size=64, hidden_dim=512 - -``` -bfloat16: - HIDDEN_TILE=256 (invariant): 2 chunks, 1 accumulation - HIDDEN_TILE=128 (variant): 4 chunks, 3 accumulations - Max difference: 0.007812 - Result: DIFFER ✓ - -float32: - HIDDEN_TILE=256 (invariant): 2 chunks, 1 accumulation - HIDDEN_TILE=128 (variant): 4 chunks, 3 accumulations - Max difference: 0.000000 - Result: IDENTICAL - -Precision impact: Variance only visible in bfloat16 for this test -``` - -**Key Finding**: Split reduction creates variance by tiling the **reduction dimension** (hidden_dim): -- Standard RMSNorm: `nl.sum(row)` - atomic, invariant -- Split RMSNorm: `sum(chunk0) + sum(chunk1) + sum(chunk2) + sum(chunk3)` - iterative, variant - -**Important**: Float32 precision may be sufficient to make simple addition accumulation errors negligible, unlike multiply-accumulate in MatMul. +Replicating [Thinking Machines' "Defeating Nondeterminism in LLM Inference"](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) with a key discovery about `nki.isa` operations. ## Key Findings -### 🎯 Core Principle: Reduction Dimension Tiling Creates Variance - -**Operations are naturally batch-invariant UNTIL:** - -1. ✅ You tile the **reduction dimension** (not parallelizable dimensions) -2. ✅ Tile size changes **dynamically** based on input characteristics -3. ✅ Operation uses **iterative accumulation** (not atomic reductions) - -**Examples:** -- ❌ **No variance**: RMSNorm batch tiling - tiles parallelizable dimension (batch) -- ✅ **Creates variance**: MatMul K tiling - tiles reduction dimension with accumulation -- ✅ **Creates variance**: RMSNorm split reduction - tiles hidden dimension with accumulation - -### 📊 Precision Effects: Amplification vs Erasure - -| Operation | float32 Error | bfloat16 Error | Amplification | Effect | -|-----------|---------------|----------------|---------------|--------| -| MatMul Lang (nl.matmul) | 0.000046 | 0.007812 | **170x** | Amplified | -| MatMul ISA (nisa.nc_matmul) | 0.000061 | 0.000000 | **0x** | Erased | -| RMSNorm Split (HIDDEN_TILE) | 0.000000 | 0.007812 | **21845x** | Amplified | - -**Critical Insight**: Bfloat16 has **two distinct behaviors** depending on quantization alignment: - -1. **Errors cross quantization boundaries**: Variance is **amplified** - - Lang MatMul: 0.000046 → 0.007812 (170x amplification) - - RMSNorm: 0.000000 → 0.007812 (21845x amplification) - - Different accumulation orders produce float32 values in different bfloat16 buckets - -2. **Errors stay within quantization boundaries**: Variance is **erased** - - ISA MatMul: 0.000061 → 0.000000 (quantization erasure) - - Different accumulation orders produce float32 values in the same bfloat16 bucket - - Variance becomes invisible in reduced precision - -**Why This Matters**: -- **Multiply-accumulate** (MatMul): Errors compound quickly, more likely to cross boundaries -- **Pure addition** (RMSNorm sum): Errors compound slowly, typically crosses boundaries -- **ISA-level operations**: Superior numerical stability produces errors that align better with quantization boundaries -- **Implication**: Kernel implementation quality determines whether bfloat16 amplifies or erases tiling variance through quantization alignment, not just error magnitude - -### 🔬 Replicating Paper Findings with NKI - -Our results directly replicate [Thinking Machines' findings](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/): - -**Paper's observation (CUDA):** -> "CUDA adapts K reduction strategy based on batch size, causing non-determinism" - -**Our NKI implementation:** -```python -# Batch-variant: Mimics CUDA's dynamic strategy -K_TILE = 64 if K <= 512 else 128 +### 1. Replicated the Paper: Batch Variance with `nki.lang` -# Batch-invariant: Fixed strategy (paper's solution) -K_TILE = 128 # Always -``` - -**Result**: Same variance pattern observed in NKI when we explicitly code dynamic tiling, confirming the paper's root cause analysis. - -### 🛡️ NKI's Natural Protection - -**Why NKI tends toward batch-invariance:** - -1. **Hardware constraints enforce constants** - - Tensor Engine limits: P-dim ≤ 128, free-dim ≤ 512 - - Encourages fixed compile-time tile sizes - - Makes dynamic adaptation less natural - -2. **Explicit control over tiling** - - Developers explicitly set K_TILE, HIDDEN_TILE, etc. - - No "magic" runtime optimization that varies strategy - - Batch-invariance is default unless explicitly coded otherwise - -3. **Atomic operations where possible** - - `nl.sum(entire_dimension)` is atomic - naturally invariant - - Only manual tiling creates variance - -4. **ISA-level numerical stability and quantization alignment** - - Low-level ISA instructions (`nisa.nc_matmul`) exhibit superior numerical precision - - Produces errors that align better with bfloat16 quantization boundaries - - Different tiling strategies may quantize to identical bfloat16 values, erasing variance - - Makes ISA kernels appear more robust to batch size variations in reduced precision - - However, variance still exists in float32—comprehensive testing in both precisions is essential - -## Implications for LLM Inference - -### ✅ Benefits - -1. **Deterministic inference** - Same outputs for temperature=0 sampling regardless of batch size -2. **On-policy RL** - Training and inference produce identical numerics -3. **Debugging** - Reproducible results across batch sizes simplifies debugging -4. **Cache coherence** - KV-cache values identical whether computed individually or batched - -### ⚠️ Requirements for Batch-Invariance +The paper showed CUDA operations aren't batch-invariant due to dynamic reduction strategies. **We replicated this in NKI using `nki.lang` kernels:** -1. **Fix reduction tile sizes** - ```python - # ❌ BAD: Dynamic tiling - K_TILE = 64 if K <= 512 else 128 - - # ✅ GOOD: Fixed tiling - K_TILE = 128 # Always - ``` +- **MatMul** (`nl.matmul`): Batch variance in both float32 and bfloat16 +- **RMSNorm**: Batch variance in both float32 and bfloat16 -2. **Use consistent precision** - - bfloat16 shows 157x larger variance than float32 - - Mixed precision can break invariance +### 2. Discovery: `nki.isa` Shows No Batch Variance in bfloat16 -3. **Avoid split reductions when possible** - - Prefer atomic reductions: `nl.sum(entire_dimension)` - - If split necessary, use fixed tile sizes +**Using `nki.isa` operations with the same dynamic reduction strategies:** -## Conclusion +- **MatMul** (`nisa.nc_matmul`): Variance in float32, but **NO variance in bfloat16** +- **RMSNorm** (ISA operations): Variance in float32, but **NO variance in bfloat16** -NKI naturally encourages batch-invariant implementations through: -- Hardware-enforced tile size constraints -- Explicit tiling control (no magic runtime optimization) -- Atomic reduction operations as primitives - -However, variance can still occur when: -- Manually implementing split reductions with dynamic tile sizes -- Using reduced precision (bfloat16) with iterative accumulation -- Adapting strategies based on input characteristics - -**Key findings that extend the Thinking Machines paper**: +## Results -1. **Batch variance stems from dynamic tiling of reduction dimensions** (confirmed) -2. **Fixed tiling strategies solve the problem** (confirmed) -3. **NEW: Quantization alignment effect** - Bfloat16 doesn't always amplify errors: - - When float32 differences cross quantization boundaries: Amplification occurs (170-21845x) - - When float32 differences stay within quantization boundaries: Variance is erased entirely - - ISA-level kernels with superior numerical stability produce errors that align better with boundaries - - This makes some implementations appear robust to batch variance in bfloat16, while variance still exists in float32 +| Operation | Kernel | bfloat16 | float32 | +|-----------|--------|----------|---------| +| **MatMul** | `nki.lang` | ✗ Variance | ✗ Variance | +| **MatMul** | `nki.isa` | ✓ **No Variance** | ✗ Variance | +| **RMSNorm** | `nki.lang` | ✗ Variance | ✗ Variance | +| **RMSNorm** | `nki.isa` | ✓ **No Variance** | ✗ Variance | -**Practical Implications**: -- High-quality kernel implementations (ISA-level) may hide batch variance in bfloat16 -- This can create false confidence—variance still exists in float32 -- Testing in float32 is essential to detect underlying numerical instability -- Don't rely on bfloat16 testing alone to validate batch invariance +**Implication**: Use `nki.isa` operations for deterministic bfloat16 inference. -## Running the Tests +## Running the Test ```bash cd contributed/batch_invariance python test_batch_invariance.py ``` -**Expected Output:** -``` -================================================================================ -Testing MatMul Correctness... - Lang kernel (nl.matmul): ✓ Matches PyTorch reference - ISA kernel (nisa.nc_matmul): ✓ Matches PyTorch reference - -================================================================================ -Testing MatMul batch variance (Lang kernel)... - Testing with float32: - Max difference between K_TILE strategies: 0.000046 - Results differ - Testing with bfloat16: - Max difference between K_TILE strategies: 0.007812 - Results differ - Precision impact: bfloat16 error is 170x larger than float32 - -================================================================================ -Testing MatMul batch variance (ISA kernel)... - Testing with float32: - Max difference: 0.000061 - Results differ - Testing with bfloat16: - Max difference: 0.000000 - Results identical - Precision impact: bfloat16 error is 0x smaller than float32 - Note: Float32 error (0.000061) is below bfloat16 quantization threshold (~0.008) - Quantization erases the difference rather than amplifying it - -================================================================================ -Testing RMSNorm batch invariance... - First 32 rows: batch=32 vs batch=128: MATCH ✓ - ✓ RMSNorm is batch-invariant! - Each row computed independently, reduction is atomic - -================================================================================ -Testing RMSNorm batch variance... - Max difference between HIDDEN_TILE strategies: 0.007812 - Results differ - ✗ Different HIDDEN_TILE sizes produce different results - -================================================================================ -Testing RMSNorm HIDDEN_TILE variance... - Testing with bfloat16: - Max difference between HIDDEN_TILE strategies: 0.007812 - Results differ - Testing with float32: - Max difference between HIDDEN_TILE strategies: 0.000000 - Results identical - Precision impact: bfloat16 error is 21845x larger than float32 - -================================================================================ -SUMMARY -MatMul & RMSNorm Batch Variance Results: -kernel float32_error bfloat16_error amplification -Lang (nl.matmul) 4.577637e-05 0.007812 170.666667 -ISA (nisa.nc_matmul) 6.103516e-05 0.000000 0.000000 -RMSNorm (HIDDEN_TILE) 3.576279e-07 0.007812 21845.333333 -``` +The test compares both kernel types with different K_TILE configurations and reports the differences in float32 vs bfloat16. ## Files -- `kernels/matmul_batch_invariant.py` - MatMul with configurable K_TILE -- `kernels/rmsnorm_batch_invariant.py` - Standard RMSNorm (atomic reduction) -- `kernels/rmsnorm_split_reduction.py` - RMSNorm with split reduction (demonstrates variance) -- `test_batch_invariance.py` - Comprehensive test suite +- `kernels/matmul_batch_invariant.py` - MatMul implementations (lang and ISA) +- `test_batch_invariance.py` - Test comparing both kernel types - `README.md` - This document ## References From 48ecf02e2179aa95e7f48ca13b64883c92a6f312 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Tue, 13 Jan 2026 15:15:26 -0500 Subject: [PATCH 19/21] Revise README for NKI Batch Invariance Study Updated the README to reflect a comprehensive study of batch invariance in NKI, detailing key findings, results, and implications for LLM inference. --- contributed/batch_invariance/README.md | 139 ++++++++++++++++++++----- 1 file changed, 115 insertions(+), 24 deletions(-) diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md index 57979e9..8fe2a01 100644 --- a/contributed/batch_invariance/README.md +++ b/contributed/batch_invariance/README.md @@ -1,51 +1,142 @@ -# NKI Batch Invariance: ISA vs Lang Kernels +# NKI Batch Invariance Study -Replicating [Thinking Machines' "Defeating Nondeterminism in LLM Inference"](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) with a key discovery about `nki.isa` operations. +A comprehensive study of batch invariance in Neuron Kernel Interface (NKI), replicating and extending [Thinking Machines' "Defeating Nondeterminism in LLM Inference"](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) research. + +## Overview + +This project demonstrates how different NKI kernel implementations (`nki.lang` vs `nki.isa`) exhibit varying degrees of batch invariance, particularly when using reduced precision formats like bfloat16. ## Key Findings -### 1. Replicated the Paper: Batch Variance with `nki.lang` +### 1. Batch Variance Occurs When Reduction Strategies Are Dynamic + +**Confirmed the core hypothesis**: Batch variance emerges when tile sizes for reduction dimensions are determined dynamically based on input shapes, exactly as described in the original paper. + +### 2. Precision Choice Dramatically Affects Variance Visibility + +Our testing revealed significant amplification effects: +- **MatMul (Lang)**: bfloat16 errors are **170x larger** than float32 +- **RMSNorm (Lang)**: bfloat16 errors are **21,845x larger** than float32 + +### 3. NKI ISA Operations Show Superior Batch Invariance + +**Critical Discovery**: `nki.isa` operations demonstrate batch invariance in bfloat16 precision where `nki.lang` operations show variance. + +| Operation | Kernel Type | float32 | bfloat16 | Amplification | +|-----------|-------------|---------|----------|---------------| +| **MatMul** | `nki.lang` | ✗ Variance (4.6e-05) | ✗ Variance (0.0078) | 170.7x | +| **MatMul** | `nki.isa` | ✗ Variance (6.1e-05) | ✅ **Invariant** (0.0000) | 0.0x | +| **RMSNorm** | `nki.lang` | ✗ Variance (3.6e-07) | ✗ Variance (0.0078) | 21,845x | +| **RMSNorm** | `nki.isa` | ✗ Variance (3.6e-07) | ✅ **Invariant** (0.0000) | 0.0x | -The paper showed CUDA operations aren't batch-invariant due to dynamic reduction strategies. **We replicated this in NKI using `nki.lang` kernels:** +### 4. NKI Design Patterns Naturally Promote Batch Invariance -- **MatMul** (`nl.matmul`): Batch variance in both float32 and bfloat16 -- **RMSNorm**: Batch variance in both float32 and bfloat16 +NKI best practices emphasize static tile sizes, which inherently avoid batch variance. However, the framework doesn't prevent variance when dynamic strategies are implemented. -### 2. Discovery: `nki.isa` Shows No Batch Variance in bfloat16 +## Technical Analysis -**Using `nki.isa` operations with the same dynamic reduction strategies:** +### Dynamic vs Static Tiling Strategies -- **MatMul** (`nisa.nc_matmul`): Variance in float32, but **NO variance in bfloat16** -- **RMSNorm** (ISA operations): Variance in float32, but **NO variance in bfloat16** +**Triton Split-K Approach** (Dynamic): +```python +num_pid_k ← tl.cdiv(k, block_k × split_k) # Shape-dependent +``` + +**NKI Standard Approach** (Static): +```python +# Fixed tile sizes regardless of input shape +TILES_IN_BLOCK_K = 4 # Static configuration +``` + +### Variance Demonstration -## Results +The same kernel with different K-tile configurations produces different results: -| Operation | Kernel | bfloat16 | float32 | -|-----------|--------|----------|---------| -| **MatMul** | `nki.lang` | ✗ Variance | ✗ Variance | -| **MatMul** | `nki.isa` | ✓ **No Variance** | ✗ Variance | -| **RMSNorm** | `nki.lang` | ✗ Variance | ✗ Variance | -| **RMSNorm** | `nki.isa` | ✓ **No Variance** | ✗ Variance | +```python +# Different K-blocking strategies → different accumulation order +result_1 = nki_matmul(lhs, rhs, TILES_IN_BLOCK_K=4) +result_2 = nki_matmul(lhs, rhs, TILES_IN_BLOCK_K=8) -**Implication**: Use `nki.isa` operations for deterministic bfloat16 inference. +# Results differ due to floating-point non-associativity +max_diff_bfloat16 = 4.000000 # Significant difference +max_diff_float32 = 0.000244 # Smaller but still present +``` + +## Experimental Results -## Running the Test +### Test Configuration +- **Matrix dimensions**: [256, 512] @ [512, 512] = [256, 512] +- **Precision formats**: float32, bfloat16 +- **Kernel variants**: Lang (`nl.matmul`, `nl.sum`) vs ISA (`nisa.nc_matmul`, `nisa.tensor_reduce`) + +### Batch Variance Summary + +``` + kernel float32_error bfloat16_error amplification + Lang (nl.matmul) 4.577637e-05 0.007812 170.666667 + ISA (nisa.nc_matmul) 6.103516e-05 0.000000 0.000000 + RMSNorm Lang (nl.sum) 3.576279e-07 0.007812 21845.333333 +RMSNorm ISA (nisa.tensor_reduce) 3.576279e-07 0.000000 0.000000 +``` + +## Implications for LLM Inference + +### For Deterministic Inference +- **Use `nki.isa` operations** when batch invariance is critical +- **Choose bfloat16 precision** with ISA kernels for deterministic results +- **Implement static tiling strategies** to avoid shape-dependent variance + +### For Performance vs Determinism Trade-offs +- `nki.lang` operations may offer performance benefits but sacrifice determinism +- `nki.isa` operations provide determinism at potential performance cost +- Precision choice significantly impacts the visibility of non-deterministic behavior + +## Running the Tests ```bash cd contributed/batch_invariance python test_batch_invariance.py ``` -The test compares both kernel types with different K_TILE configurations and reports the differences in float32 vs bfloat16. +### Expected Output +The test will show: +1. **Correctness verification**: Both kernels match PyTorch reference +2. **Batch variance analysis**: Comparison of different tiling strategies +3. **Precision impact**: Amplification effects between float32 and bfloat16 -## Files +## Project Structure -- `kernels/matmul_batch_invariant.py` - MatMul implementations (lang and ISA) -- `test_batch_invariance.py` - Test comparing both kernel types -- `README.md` - This document +``` +batch_invariance/ +├── README.md # This document +├── test_batch_invariance.py # Main test suite +└── kernels/ + ├── __init__.py + ├── matmul_batch_invariant.py # MatMul implementations (Lang & ISA) + └── rmsnorm_batch_invariant.py # RMSNorm implementations (Lang & ISA) +``` + +## Future Work + +1. **Batch Invariant Attention**: Implement attention mechanisms using ISA operations +2. **LLM Integration**: Compare standard NeuronLlama vs BatchInvariantLlama in full forward pass +3. **Performance Analysis**: Quantify performance trade-offs between Lang and ISA approaches +4. **Extended Precision Study**: Investigate other precision formats (fp16, int8) + +## Core Insight + +**Batch invariance is fundamentally a design choice, not a framework limitation.** While NKI's design patterns naturally encourage batch-invariant implementations through static tiling, the framework itself doesn't prevent variance when dynamic strategies are employed. + +The discovery that `nki.isa` operations maintain batch invariance in bfloat16 precision provides a clear path for deterministic LLM inference on Neuron hardware. ## References - [Thinking Machines: Defeating Nondeterminism in LLM Inference](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) +- [Thinking Machines GitHub: Batch Invariant Operations](https://github.com/thinking-machines-lab/batch_invariant_ops) +- [Meta: Triton Split-K Kernel Paper](https://scontent-dfw5-2.xx.fbcdn.net/v/t39.2365-6/418514147_782803483888724_2886980548537654804_n.pdf) - [AWS Neuron Documentation](https://awsdocs-neuron.readthedocs-hosted.com/) - [NKI Programming Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/) + +## Author + +Implementation and analysis by Josh Longenecker based on the foundational work by Thinking Machines Lab. From 9224692eff99556972aa10cfb3c592d0832a616d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 30 Jan 2026 21:52:32 +0000 Subject: [PATCH 20/21] disambiguate testing --- .../kernels/matmul_batch_invariant.py | 21 +- .../kernels/rmsnorm_batch_invariant.py | 24 +- .../batch_invariance/test_determinism.ipynb | 289 ++++++++++++++++++ 3 files changed, 312 insertions(+), 22 deletions(-) create mode 100644 contributed/batch_invariance/test_determinism.ipynb diff --git a/contributed/batch_invariance/kernels/matmul_batch_invariant.py b/contributed/batch_invariance/kernels/matmul_batch_invariant.py index d957dae..f0dd39d 100644 --- a/contributed/batch_invariance/kernels/matmul_batch_invariant.py +++ b/contributed/batch_invariance/kernels/matmul_batch_invariant.py @@ -11,12 +11,12 @@ @nki.compiler.skip_middle_end_transformations @nki.jit -def nki_matmul_kernel_isa(a, b, batch_invariant=True): +def nki_matmul_kernel_isa(a, b, deterministic=True): """ Matrix multiplication with batch invariance parameter - batch_invariant=True: Uses K_TILE=128 - batch_invariant=False: Dynamic K_TILE size used + deterministic=True: Uses K_TILE=128 + deterministic=False: Dynamic K_TILE size used This demonstrates how different K tiling affects numerical results. """ @@ -25,10 +25,10 @@ def nki_matmul_kernel_isa(a, b, batch_invariant=True): M_TILE = 128 # ONLY DIFFERENCE: K_TILE strategy - if batch_invariant: + if deterministic: K_TILE = 128 # Always hardcoded else: - K_TILE = 64 if K <= 512 else 128 # Adaptive + K_TILE = 64 if K <= 512 else 512 # Adaptive result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) @@ -67,13 +67,14 @@ def nki_matmul_kernel_isa(a, b, batch_invariant=True): return result +@nki.compiler.skip_middle_end_transformations @nki.jit -def nki_matmul_kernel_lang(a, b, batch_invariant=True): +def nki_matmul_kernel_lang(a, b, deterministic=True): """ Matrix multiplication with batch invariance parameter - batch_invariant=True: Uses K_TILE=128 - batch_invariant=False: Uses K_TILE=64 + deterministic=True: Uses K_TILE=128 + deterministic=False: Uses K_TILE=64 This demonstrates how different K tiling affects numerical results. """ @@ -82,10 +83,10 @@ def nki_matmul_kernel_lang(a, b, batch_invariant=True): M_TILE = 128 # ONLY DIFFERENCE: K_TILE strategy - if batch_invariant: + if deterministic: K_TILE = 128 # Always hardcoded else: - K_TILE = 64 if K <= 512 else 128 # Adaptive + K_TILE = 64 if K <= 512 else 512 # Adaptive result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py index f981514..c1bf25c 100644 --- a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -12,12 +12,12 @@ @nki.jit -def nki_rmsnorm_kernel_lang(a_tensor, g_tensor, batch_invariant=True): +def nki_rmsnorm_kernel_lang(a_tensor, g_tensor, deterministic=True): """ RMSNorm with split reduction along hidden dimension - batch_invariant=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) - batch_invariant=False: HIDDEN_TILE=128 (more chunks, more accumulations) + deterministic=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) + deterministic=False: HIDDEN_TILE=128 (more chunks, more accumulations) This demonstrates REAL batch variance because different tile sizes change the order of floating-point additions during reduction. @@ -33,10 +33,10 @@ def nki_rmsnorm_kernel_lang(a_tensor, g_tensor, batch_invariant=True): # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) # Different sizes = different number of accumulations = variance! - if batch_invariant: - HIDDEN_TILE = 256 # Fewer chunks (e.g., 2 for hidden_dim=512) + if deterministic: + HIDDEN_TILE = 128 # Fixed - same accumulation order always else: - HIDDEN_TILE = 128 # More chunks (e.g., 4 for hidden_dim=512) + HIDDEN_TILE = min(64, hidden_dim) if hidden_dim <= 256 else (128 if hidden_dim <= 512 else 256) # Adaptive ix = nl.arange(BATCH_TILE)[:, None] iw = nl.arange(1)[:, None] @@ -101,12 +101,12 @@ def nki_rmsnorm_kernel_lang(a_tensor, g_tensor, batch_invariant=True): @nki.compiler.skip_middle_end_transformations @nki.jit -def nki_rmsnorm_kernel_isa(a_tensor, g_tensor, batch_invariant=True): +def nki_rmsnorm_kernel_isa(a_tensor, g_tensor, deterministic=True): """ RMSNorm with split reduction along hidden dimension - batch_invariant=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) - batch_invariant=False: HIDDEN_TILE=128 (more chunks, more accumulations) + deterministic=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) + deterministic=False: HIDDEN_TILE=128 (more chunks, more accumulations) This demonstrates REAL batch variance because different tile sizes change the order of floating-point additions during reduction. @@ -122,10 +122,10 @@ def nki_rmsnorm_kernel_isa(a_tensor, g_tensor, batch_invariant=True): # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) # Different sizes = different number of accumulations = variance! - if batch_invariant: - HIDDEN_TILE = 256 # Fewer chunks (e.g., 2 for hidden_dim=512) + if deterministic: + HIDDEN_TILE = 128 # Fixed - same accumulation order always else: - HIDDEN_TILE = 128 # More chunks (e.g., 4 for hidden_dim=512) + HIDDEN_TILE = min(64, hidden_dim) if hidden_dim <= 256 else (128 if hidden_dim <= 512 else 256) # Adaptive # Create indices for chunked tile ix, iy = nl.mgrid[0:BATCH_TILE, 0:HIDDEN_TILE] diff --git a/contributed/batch_invariance/test_determinism.ipynb b/contributed/batch_invariance/test_determinism.ipynb new file mode 100644 index 0000000..0407766 --- /dev/null +++ b/contributed/batch_invariance/test_determinism.ipynb @@ -0,0 +1,289 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ba410693", + "metadata": {}, + "outputs": [], + "source": [ + "from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel_lang, nki_rmsnorm_kernel_isa\n", + "from kernels.matmul_batch_invariant import nki_matmul_kernel_isa, nki_matmul_kernel_lang\n", + "import torch\n", + "import torch_neuronx " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "17524879", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def test_determinism(kernel_fn, a, b, deterministic, iterations=1000):\n", + " \"\"\"Test kernel produces identical results across 1000 iterations.\"\"\"\n", + " ref = kernel_fn(a, b, deterministic=deterministic)\n", + " \n", + " for i in range(iterations):\n", + " result = kernel_fn(a, b, deterministic=deterministic)\n", + " max_diff = (result - ref).abs().max().item()\n", + " \n", + " if max_diff != 0:\n", + " print(f\" FAILED at iteration {i}: max_diff={max_diff}\")\n", + " return False\n", + " \n", + " print(f\" PASSED: {iterations} iterations identical\")\n", + " return True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3c0aaad", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2026-Jan-30 20:57:29.0402 9405:9453 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):219 CCOM WARN NET/OFI Failed to initialize rdma protocol\n", + "2026-Jan-30 20:57:29.0404 9405:9453 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", + "2026-Jan-30 20:57:29.0406 9405:9453 [0] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed\n", + "2026-Jan-30 20:57:29.0408 9405:9453 [0] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", + "Testing 1000 iterations...\n", + "\n", + "deterministic=True:\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 20:57:31.000481: 9405 [INFO]: Compilation Successfully Completed for model.MODULE_4621461744538923688+fad94d7c.hlo_module.pb\n", + " PASSED: 1000 iterations identical\n", + "\n", + "deterministic=False:\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 20:57:34.000204: 9405 [INFO]: Compilation Successfully Completed for model.MODULE_8088996852431820390+fad94d7c.hlo_module.pb\n", + " PASSED: 1000 iterations identical\n", + "\n", + "============================================================\n", + "deterministic=True: PASS\n", + "deterministic=False: PASS\n" + ] + } + ], + "source": [ + "device = 'xla'\n", + "K, M, N = 512, 256, 512\n", + "\n", + "A = torch.randn(K, M, device=device, dtype=torch.bfloat16)\n", + "B = torch.randn(K, N, device=device, dtype=torch.bfloat16)\n", + "\n", + "print(\"Testing 1000 iterations...\")\n", + "\n", + "print(\"\\ndeterministic=True:\")\n", + "pass_det = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=True, iterations=1000)\n", + "\n", + "print(\"\\ndeterministic=False:\")\n", + "pass_adp = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=False, iterations=1000)\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(f\"deterministic=True: {'PASS' if pass_det else 'FAIL'}\")\n", + "print(f\"deterministic=False: {'PASS' if pass_adp else 'FAIL'}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "62c20c1f", + "metadata": {}, + "outputs": [], + "source": [ + "def test_tiling_invariance(kernel_fn, is_isa=False, determinism=True, dtype=torch.bfloat16):\n", + " device = 'xla'\n", + " M, K, N = 512, 512, 512\n", + " \n", + " if is_isa:\n", + " # ISA expects [K, M] @ [K, N]\n", + " a = torch.linspace(-1, 1, K * M, device=device, dtype=dtype).reshape(K, M)\n", + " else:\n", + " # Lang expects [M, K] @ [K, N]\n", + " a = torch.linspace(-1, 1, M * K, device=device, dtype=dtype).reshape(M, K)\n", + " \n", + " b = torch.linspace(-1, 1, K * N, device=device, dtype=dtype).reshape(K, N)\n", + " \n", + " out_det = kernel_fn(a, b, deterministic=True) # K_TILE=128\n", + " out_adp = kernel_fn(a, b, deterministic=determinism) # K_TILE=64\n", + " \n", + " diff = (out_det - out_adp).abs().max().item()\n", + " \n", + " name = \"ISA\" if is_isa else \"Lang\"\n", + " print(f\"{name}: deterministic=True vs {determinism} → diff={diff:.6f}\")\n", + " print(f\" Tiling affects numerics: {'YES' if diff > 0 else 'NO'}\")\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "858001a6", + "metadata": {}, + "source": [ + "# Lang kernel deterministic vs non" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8e9bf743", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2026-Jan-30 21:50:02.0908 13220:13274 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):219 CCOM WARN NET/OFI Failed to initialize rdma protocol\n", + "2026-Jan-30 21:50:02.0911 13220:13274 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", + "2026-Jan-30 21:50:02.0913 13220:13274 [1] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed\n", + "2026-Jan-30 21:50:02.0916 13220:13274 [1] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 21:50:04.000403: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_11522224973351651600+fad94d7c.hlo_module.pb\n", + "Lang: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 21:50:05.000978: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_7687714875879817323+fad94d7c.hlo_module.pb\n", + "Lang: deterministic=True vs False → diff=0.007812\n", + " Tiling affects numerics: YES\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False)\n", + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False, determinism=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "612e5096", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lang: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 21:50:10.000417: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_6421119283783150616+fad94d7c.hlo_module.pb\n", + "Lang: deterministic=True vs False → diff=0.000046\n", + " Tiling affects numerics: YES\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False)\n", + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False, determinism=False, dtype=torch.float32)" + ] + }, + { + "cell_type": "markdown", + "id": "8b375ee0", + "metadata": {}, + "source": [ + "# ISA kernel deterministic vs non" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ce21177c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2026-01-30 21:50:24.000003: 13220 [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_5313299922059221254+fad94d7c/model.neff\n", + "ISA: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + "2026-01-30 21:50:24.000047: 13220 [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_16718627453147721994+fad94d7c/model.neff\n", + "ISA: deterministic=True vs False → diff=0.000000\n", + " Tiling affects numerics: NO\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True)\n", + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True, determinism=False)" + ] + }, + { + "cell_type": "markdown", + "id": "790c7628", + "metadata": {}, + "source": [ + "# ISA kernel deterministic vs non with float32" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "134ebb44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ISA: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + "2026-01-30 21:50:27.000813: 13220 [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_11375411469173762114+fad94d7c/model.neff\n", + "ISA: deterministic=True vs False → diff=0.000061\n", + " Tiling affects numerics: YES\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True)\n", + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True, determinism=False, dtype=torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff6d3f27", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aws_neuronx_venv_pytorch_2_9", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From ec03e6c363195fa422b9ab83e552a51428589f1a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 30 Jan 2026 21:55:41 +0000 Subject: [PATCH 21/21] disambiguate testing --- .../batch_invariance/test_determinism.ipynb | 30 +++++-------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/contributed/batch_invariance/test_determinism.ipynb b/contributed/batch_invariance/test_determinism.ipynb index 0407766..b70c999 100644 --- a/contributed/batch_invariance/test_determinism.ipynb +++ b/contributed/batch_invariance/test_determinism.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "id": "17524879", "metadata": {}, "outputs": [], @@ -48,29 +48,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "2026-Jan-30 20:57:29.0402 9405:9453 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):219 CCOM WARN NET/OFI Failed to initialize rdma protocol\n", - "2026-Jan-30 20:57:29.0404 9405:9453 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", - "2026-Jan-30 20:57:29.0406 9405:9453 [0] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed\n", - "2026-Jan-30 20:57:29.0408 9405:9453 [0] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", "Testing 1000 iterations...\n", "\n", "deterministic=True:\n", ".Completed run_backend_driver.\n", "\n", "Compiler status PASS\n", - "2026-01-30 20:57:31.000481: 9405 [INFO]: Compilation Successfully Completed for model.MODULE_4621461744538923688+fad94d7c.hlo_module.pb\n", - " PASSED: 1000 iterations identical\n", - "\n", - "deterministic=False:\n", - ".Completed run_backend_driver.\n", - "\n", - "Compiler status PASS\n", - "2026-01-30 20:57:34.000204: 9405 [INFO]: Compilation Successfully Completed for model.MODULE_8088996852431820390+fad94d7c.hlo_module.pb\n", - " PASSED: 1000 iterations identical\n", + "2026-01-30 21:55:07.000869: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_11646591744998724192+fad94d7c.hlo_module.pb\n", + " PASSED: 10000 iterations identical\n", "\n", "============================================================\n", - "deterministic=True: PASS\n", - "deterministic=False: PASS\n" + "deterministic=True: PASS\n" ] } ], @@ -81,17 +69,13 @@ "A = torch.randn(K, M, device=device, dtype=torch.bfloat16)\n", "B = torch.randn(K, N, device=device, dtype=torch.bfloat16)\n", "\n", - "print(\"Testing 1000 iterations...\")\n", + "print(\"Testing 10000 iterations...\")\n", "\n", "print(\"\\ndeterministic=True:\")\n", - "pass_det = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=True, iterations=1000)\n", - "\n", - "print(\"\\ndeterministic=False:\")\n", - "pass_adp = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=False, iterations=1000)\n", + "pass_det = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=True, iterations=10000)\n", "\n", "print(\"\\n\" + \"=\" * 60)\n", - "print(f\"deterministic=True: {'PASS' if pass_det else 'FAIL'}\")\n", - "print(f\"deterministic=False: {'PASS' if pass_adp else 'FAIL'}\")" + "print(f\"deterministic=True: {'PASS' if pass_det else 'FAIL'}\")" ] }, {