diff --git a/contributed/batch_invariance/README.md b/contributed/batch_invariance/README.md new file mode 100644 index 0000000..8fe2a01 --- /dev/null +++ b/contributed/batch_invariance/README.md @@ -0,0 +1,142 @@ +# NKI Batch Invariance Study + +A comprehensive study of batch invariance in Neuron Kernel Interface (NKI), replicating and extending [Thinking Machines' "Defeating Nondeterminism in LLM Inference"](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) research. + +## Overview + +This project demonstrates how different NKI kernel implementations (`nki.lang` vs `nki.isa`) exhibit varying degrees of batch invariance, particularly when using reduced precision formats like bfloat16. + +## Key Findings + +### 1. Batch Variance Occurs When Reduction Strategies Are Dynamic + +**Confirmed the core hypothesis**: Batch variance emerges when tile sizes for reduction dimensions are determined dynamically based on input shapes, exactly as described in the original paper. + +### 2. Precision Choice Dramatically Affects Variance Visibility + +Our testing revealed significant amplification effects: +- **MatMul (Lang)**: bfloat16 errors are **170x larger** than float32 +- **RMSNorm (Lang)**: bfloat16 errors are **21,845x larger** than float32 + +### 3. NKI ISA Operations Show Superior Batch Invariance + +**Critical Discovery**: `nki.isa` operations demonstrate batch invariance in bfloat16 precision where `nki.lang` operations show variance. + +| Operation | Kernel Type | float32 | bfloat16 | Amplification | +|-----------|-------------|---------|----------|---------------| +| **MatMul** | `nki.lang` | ✗ Variance (4.6e-05) | ✗ Variance (0.0078) | 170.7x | +| **MatMul** | `nki.isa` | ✗ Variance (6.1e-05) | ✅ **Invariant** (0.0000) | 0.0x | +| **RMSNorm** | `nki.lang` | ✗ Variance (3.6e-07) | ✗ Variance (0.0078) | 21,845x | +| **RMSNorm** | `nki.isa` | ✗ Variance (3.6e-07) | ✅ **Invariant** (0.0000) | 0.0x | + +### 4. NKI Design Patterns Naturally Promote Batch Invariance + +NKI best practices emphasize static tile sizes, which inherently avoid batch variance. However, the framework doesn't prevent variance when dynamic strategies are implemented. 
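+
+The mechanism behind findings 1 and 2 can be reproduced off-device with plain PyTorch. The sketch below is illustrative only (it is not an NKI kernel and assumes nothing about Neuron hardware): it sums the same vector with two different chunk sizes, mimicking a static versus a shape-dependent tiling strategy, and the two results generally drift apart, far more visibly in bfloat16 than in float32:
+
+```python
+import torch
+
+torch.manual_seed(0)
+x = torch.randn(4096)
+
+def chunked_sum(values, chunk):
+    # Accumulate per-chunk partial sums in the working precision,
+    # mimicking a tiled reduction with tile size `chunk`.
+    acc = torch.zeros((), dtype=values.dtype)
+    for i in range(0, values.numel(), chunk):
+        acc = acc + values[i:i + chunk].sum()
+    return acc
+
+for dtype in (torch.float32, torch.bfloat16):
+    v = x.to(dtype)
+    static = chunked_sum(v, 128)   # fixed tile size
+    dynamic = chunked_sum(v, 512)  # tile size "chosen from the shape"
+    print(dtype, (static - dynamic).abs().item())
+```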
+
+## Technical Analysis
+
+### Dynamic vs Static Tiling Strategies
+
+**Triton Split-K Approach** (Dynamic):
+```python
+num_pid_k = tl.cdiv(k, block_k * split_k)  # Shape-dependent
+```
+
+**NKI Standard Approach** (Static):
+```python
+# Fixed tile sizes regardless of input shape
+TILES_IN_BLOCK_K = 4  # Static configuration
+```
+
+### Variance Demonstration
+
+The same kernel with different K-tile configurations produces different results:
+
+```python
+# Different K-blocking strategies → different accumulation order
+result_1 = nki_matmul(lhs, rhs, TILES_IN_BLOCK_K=4)
+result_2 = nki_matmul(lhs, rhs, TILES_IN_BLOCK_K=8)
+
+# Results differ due to floating-point non-associativity
+max_diff_bfloat16 = 4.000000  # Significant difference
+max_diff_float32 = 0.000244   # Smaller but still present
+```
+
+## Experimental Results
+
+### Test Configuration
+- **Matrix dimensions**: [256, 512] @ [512, 512] = [256, 512]
+- **Precision formats**: float32, bfloat16
+- **Kernel variants**: Lang (`nl.matmul`, `nl.sum`) vs ISA (`nisa.nc_matmul`, `nisa.tensor_reduce`)
+
+### Batch Variance Summary
+
+```
+                          kernel  float32_error  bfloat16_error  amplification
+                Lang (nl.matmul)   4.577637e-05        0.007812     170.666667
+            ISA (nisa.nc_matmul)   6.103516e-05        0.000000       0.000000
+           RMSNorm Lang (nl.sum)   3.576279e-07        0.007812   21845.333333
+RMSNorm ISA (nisa.tensor_reduce)   3.576279e-07        0.000000       0.000000
+```
+
+## Implications for LLM Inference
+
+### For Deterministic Inference
+- **Use `nki.isa` operations** when batch invariance is critical
+- **Choose bfloat16 precision** with ISA kernels for deterministic results
+- **Implement static tiling strategies** to avoid shape-dependent variance
+
+### For Performance vs Determinism Trade-offs
+- `nki.lang` operations may offer performance benefits but sacrifice determinism
+- `nki.isa` operations provide determinism at potential performance cost
+- Precision choice significantly impacts the visibility of non-deterministic behavior
+
+## Running the Tests
+
+```bash
+cd contributed/batch_invariance
+python test_batch_invariance.py
+```
+
+### Expected Output
+The test will show:
+1. **Correctness verification**: Both kernels match the PyTorch reference
+2. **Batch variance analysis**: Comparison of different tiling strategies
+3. **Precision impact**: Amplification effects between float32 and bfloat16
+
+## Project Structure
+
+```
+batch_invariance/
+├── README.md                      # This document
+├── test_batch_invariance.py       # Main test suite
+├── test_determinism.ipynb         # Run-to-run and tiling-invariance notebook
+└── kernels/
+    ├── __init__.py
+    ├── matmul_batch_invariant.py  # MatMul implementations (Lang & ISA)
+    └── rmsnorm_batch_invariant.py # RMSNorm implementations (Lang & ISA)
+```
+
+## Future Work
+
+1. **Batch Invariant Attention**: Implement attention mechanisms using ISA operations
+2. **LLM Integration**: Compare standard NeuronLlama vs BatchInvariantLlama in a full forward pass
+3. **Performance Analysis**: Quantify performance trade-offs between Lang and ISA approaches
+4. **Extended Precision Study**: Investigate other precision formats (fp16, int8)
+
+## Core Insight
+
+**Batch invariance is fundamentally a design choice, not a framework limitation.** While NKI's design patterns naturally encourage batch-invariant implementations through static tiling, the framework itself doesn't prevent variance when dynamic strategies are employed.
+
+The discovery that `nki.isa` operations maintain batch invariance in bfloat16 precision provides a clear path for deterministic LLM inference on Neuron hardware.
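+
+The invariance claim can be checked directly against the kernels in this repository. Below is a minimal sketch, assuming a Neuron (trn1/inf2) instance with the Neuron SDK installed and this project's `kernels/` package on the path: it feeds the ISA matmul kernel the same 128 output rows once on their own and once embedded in a 256-row batch, then compares the overlapping rows.
+
+```python
+import torch
+import torch_neuronx  # registers the XLA/Neuron device
+from kernels.matmul_batch_invariant import nki_matmul_kernel_isa
+
+device = "xla"
+K, M, N = 512, 256, 512
+
+# ISA kernel layout: first operand is [K, M], second is [K, N]; output is [M, N]
+a_large = torch.randn(K, M, device=device, dtype=torch.bfloat16)
+b = torch.randn(K, N, device=device, dtype=torch.bfloat16)
+a_small = a_large[:, :128]  # the same 128 output rows, smaller batch
+
+out_small = nki_matmul_kernel_isa(a_small, b, deterministic=True)
+out_large = nki_matmul_kernel_isa(a_large, b, deterministic=True)
+
+# Batch invariant <=> the shared rows are bit-identical across batch sizes
+print(torch.equal(out_small.cpu(), out_large[:128].cpu()))
+```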
+ +## References + +- [Thinking Machines: Defeating Nondeterminism in LLM Inference](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) +- [Thinking Machines GitHub: Batch Invariant Operations](https://github.com/thinking-machines-lab/batch_invariant_ops) +- [Meta: Triton Split-K Kernel Paper](https://scontent-dfw5-2.xx.fbcdn.net/v/t39.2365-6/418514147_782803483888724_2886980548537654804_n.pdf) +- [AWS Neuron Documentation](https://awsdocs-neuron.readthedocs-hosted.com/) +- [NKI Programming Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/) + +## Author + +Implementation and analysis by Josh Longenecker based on the foundational work by Thinking Machines Lab. diff --git a/contributed/batch_invariance/kernels/__init__.py b/contributed/batch_invariance/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contributed/batch_invariance/kernels/matmul_batch_invariant.py b/contributed/batch_invariance/kernels/matmul_batch_invariant.py new file mode 100644 index 0000000..f0dd39d --- /dev/null +++ b/contributed/batch_invariance/kernels/matmul_batch_invariant.py @@ -0,0 +1,115 @@ +""" +Batch-Invariant MatMul Kernel + +This kernel demonstrates batch invariance in matrix multiplication by controlling +the M-dimension tiling strategy. +""" + +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa + +@nki.compiler.skip_middle_end_transformations +@nki.jit +def nki_matmul_kernel_isa(a, b, deterministic=True): + """ + Matrix multiplication with batch invariance parameter + + deterministic=True: Uses K_TILE=128 + deterministic=False: Dynamic K_TILE size used + + This demonstrates how different K tiling affects numerical results. + """ + K, M = a.shape + N = b.shape[1] + M_TILE = 128 + + # ONLY DIFFERENCE: K_TILE strategy + if deterministic: + K_TILE = 128 # Always hardcoded + else: + K_TILE = 64 if K <= 512 else 512 # Adaptive + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + for m in nl.affine_range(M // M_TILE): + # Accumulator for this M chunk + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Reduction over K + for k in nl.affine_range(K // K_TILE): + # Allocate and load a: [K_TILE, M_TILE] + i_a_p, i_a_f = nl.mgrid[0:K_TILE, 0:M_TILE] + a_tile = nl.ndarray((K_TILE, M_TILE), dtype=a.dtype, buffer=nl.sbuf) + nisa.dma_copy( + src=a[k*K_TILE + i_a_p, m*M_TILE + i_a_f], + dst=a_tile[i_a_p, i_a_f] + ) + + # Allocate and load b: [K_TILE, N] + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.ndarray((K_TILE, N), dtype=b.dtype, buffer=nl.sbuf) + nisa.dma_copy( + src=b[k*K_TILE + i_b_p, i_b_f], + dst=b_tile[i_b_p, i_b_f] + ) + + # Matmul + c_psum += nisa.nc_matmul(a_tile, b_tile) + + # Store this M chunk + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nisa.dma_copy( + src=c_sbuf[i_out_p, i_out_f], + dst=result[m*M_TILE + i_out_p, i_out_f] + ) + + return result + +@nki.compiler.skip_middle_end_transformations +@nki.jit +def nki_matmul_kernel_lang(a, b, deterministic=True): + """ + Matrix multiplication with batch invariance parameter + + deterministic=True: Uses K_TILE=128 + deterministic=False: Uses K_TILE=64 + + This demonstrates how different K tiling affects numerical results. 
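+    (With deterministic=False, K_TILE is chosen from K: 64 when K <= 512, otherwise 512.)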
+ """ + M, K = a.shape + N = b.shape[1] + M_TILE = 128 + + # ONLY DIFFERENCE: K_TILE strategy + if deterministic: + K_TILE = 128 # Always hardcoded + else: + K_TILE = 64 if K <= 512 else 512 # Adaptive + + result = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm) + + for m in nl.affine_range(M // M_TILE): + # Accumulator for this M chunk + c_psum = nl.zeros((M_TILE, N), dtype=nl.float32, buffer=nl.psum) + + # Reduction over K + for k in nl.affine_range(K // K_TILE): + # Load a: [M_TILE, K_TILE] + i_a_p, i_a_f = nl.mgrid[0:M_TILE, 0:K_TILE] + a_tile = nl.load(a[m*M_TILE + i_a_p, k*K_TILE + i_a_f]) + + # Load b: [K_TILE, N] + i_b_p, i_b_f = nl.mgrid[0:K_TILE, 0:N] + b_tile = nl.load(b[k*K_TILE + i_b_p, i_b_f]) + + # Matmul + c_psum += nl.matmul(a_tile, b_tile, transpose_x=False) + + # Store this M chunk + i_out_p, i_out_f = nl.mgrid[0:M_TILE, 0:N] + c_sbuf = nl.copy(c_psum, dtype=result.dtype) + nl.store(result[m*M_TILE + i_out_p, i_out_f], value=c_sbuf) + + return result diff --git a/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py new file mode 100644 index 0000000..c1bf25c --- /dev/null +++ b/contributed/batch_invariance/kernels/rmsnorm_batch_invariant.py @@ -0,0 +1,205 @@ +""" +RMSNorm to demonstrate Batch Variance + +This kernel tiles the HIDDEN DIMENSION (reduction axis) instead of just the batch dimension. +This creates different accumulation orders and breaks batch-invariance! +""" + +import math +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa + + +@nki.jit +def nki_rmsnorm_kernel_lang(a_tensor, g_tensor, deterministic=True): + """ + RMSNorm with split reduction along hidden dimension + + deterministic=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) + deterministic=False: HIDDEN_TILE=128 (more chunks, more accumulations) + + This demonstrates REAL batch variance because different tile sizes + change the order of floating-point additions during reduction. + """ + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) + + assert a_tensor.shape[1] == g_tensor.shape[0] + + num_rows = a_tensor.shape[0] + hidden_dim = a_tensor.shape[1] + BATCH_TILE = 128 + + # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) + # Different sizes = different number of accumulations = variance! 
+ if deterministic: + HIDDEN_TILE = 128 # Fixed - same accumulation order always + else: + HIDDEN_TILE = min(64, hidden_dim) if hidden_dim <= 256 else (128 if hidden_dim <= 512 else 256) # Adaptive + + ix = nl.arange(BATCH_TILE)[:, None] + iw = nl.arange(1)[:, None] + + # Process batch in tiles + for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): + # SPLIT REDUCTION: Accumulate partial sums across hidden dimension chunks + # Use PSUM for accumulation (always float32 internally) + partial_square_sum = nl.zeros((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.psum) + + # Iterate over hidden dimension in chunks + num_hidden_tiles = math.ceil(hidden_dim / HIDDEN_TILE) + for h in nl.affine_range(num_hidden_tiles): + h_start = h * HIDDEN_TILE + + # Create indices for this hidden chunk (always use full HIDDEN_TILE, mask later) + iy = nl.arange(HIDDEN_TILE)[None, :] + + # Create mask for valid hidden indices + valid_mask = ((i * BATCH_TILE + ix < num_rows) & + (h * HIDDEN_TILE + iy < hidden_dim)) + + # Load a CHUNK of the hidden dimension with proper indexing + a_chunk = nl.load(a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], + mask=valid_mask) + + # Square this chunk + in_square_chunk = nl.square(a_chunk) + + # Reduce this chunk (sum along hidden dimension) + # Mask ensures we only sum valid elements + chunk_sum = nl.sum(in_square_chunk, axis=[1], keepdims=True, mask=valid_mask) + + # ACCUMULATE: This is where variance enters! + # Different HIDDEN_TILE sizes mean different number of additions + partial_square_sum += chunk_sum + + # Compute mean and RMS + mean = partial_square_sum / hidden_dim + rms_reciprocal = nl.rsqrt(mean) + + # Now load full row for normalization + iy_full = nl.arange(hidden_dim)[None, :] + a_tile = nl.load(a_tensor[i * BATCH_TILE + ix, iy_full], + mask=(i * BATCH_TILE + ix < num_rows)) + + # Normalize by RMS + out_tile = nl.multiply(a_tile, rms_reciprocal) + + # Apply weight + g_tile = nl.load(g_tensor.reshape((1, hidden_dim))[iw, iy_full]) + g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) + out_tile = nl.multiply(out_tile, g_bcast, + mask=(i * BATCH_TILE + ix < num_rows)) + + # Store result + nl.store(out_tensor[i * BATCH_TILE + ix, iy_full], value=out_tile, + mask=(i * BATCH_TILE + ix < num_rows)) + + return out_tensor + + +@nki.compiler.skip_middle_end_transformations +@nki.jit +def nki_rmsnorm_kernel_isa(a_tensor, g_tensor, deterministic=True): + """ + RMSNorm with split reduction along hidden dimension + + deterministic=True: HIDDEN_TILE=256 (fewer chunks, fewer accumulations) + deterministic=False: HIDDEN_TILE=128 (more chunks, more accumulations) + + This demonstrates REAL batch variance because different tile sizes + change the order of floating-point additions during reduction. + """ + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) + + assert a_tensor.shape[1] == g_tensor.shape[0] + + num_rows = a_tensor.shape[0] + hidden_dim = a_tensor.shape[1] + BATCH_TILE = 128 + + # CRITICAL: Tile size for REDUCTION dimension (hidden_dim) + # Different sizes = different number of accumulations = variance! 
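+    # Same tile-size policy as the Lang kernel above; here the per-chunk reduction
+    # is issued with nisa.tensor_reduce and accumulated in float32 PSUM.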
+ if deterministic: + HIDDEN_TILE = 128 # Fixed - same accumulation order always + else: + HIDDEN_TILE = min(64, hidden_dim) if hidden_dim <= 256 else (128 if hidden_dim <= 512 else 256) # Adaptive + + # Create indices for chunked tile + ix, iy = nl.mgrid[0:BATCH_TILE, 0:HIDDEN_TILE] + + # Create indices for full tile + ix_full, iy_full = nl.mgrid[0:BATCH_TILE, 0:hidden_dim] + + # Load weight once using nisa.dma_copy + iw, iy_g = nl.mgrid[0:1, 0:hidden_dim] + g_tile = nl.ndarray((1, hidden_dim), dtype=g_tensor.dtype, buffer=nl.sbuf) + nisa.dma_copy( + src=g_tensor.reshape((1, hidden_dim))[iw, iy_g], + dst=g_tile[iw, iy_g] + ) + + # Loop over batch dimension + for i in nl.affine_range(math.ceil(num_rows / BATCH_TILE)): + # SPLIT REDUCTION: Accumulate partial sums across hidden dimension chunks + partial_square_sum = nl.zeros((BATCH_TILE, 1), dtype=nl.float32, buffer=nl.psum) + + # Iterate over hidden dimension in chunks + for h in nl.affine_range(math.ceil(hidden_dim / HIDDEN_TILE)): + # Allocate buffer for chunk + a_chunk = nl.ndarray((BATCH_TILE, HIDDEN_TILE), dtype=a_tensor.dtype, buffer=nl.sbuf) + + # Load chunk with mask using nisa.dma_copy + nisa.dma_copy( + src=a_tensor[i * BATCH_TILE + ix, h * HIDDEN_TILE + iy], + dst=a_chunk[ix, iy], + mask=(i * BATCH_TILE + ix < num_rows) & (h * HIDDEN_TILE + iy < hidden_dim) + ) + + # Square this chunk + chunk_square = nl.square(a_chunk) + + # Reduce this chunk (sum along hidden dimension) using nisa.tensor_reduce + chunk_sum = nisa.tensor_reduce( + nl.add, + chunk_square[ix, iy], + axis=[1], + keepdims=True, + dtype=nl.float32 + ) + + # ACCUMULATE: This is where variance enters! + # Different HIDDEN_TILE sizes mean different number of additions + partial_square_sum += chunk_sum + + # Compute mean and RMS + mean = partial_square_sum * (1.0 / hidden_dim) + rms_reciprocal = nl.rsqrt(mean) + + # Allocate buffer for full tile + a_tile = nl.ndarray((BATCH_TILE, hidden_dim), dtype=a_tensor.dtype, buffer=nl.sbuf) + + # Load full row for normalization with mask using nisa.dma_copy + nisa.dma_copy( + src=a_tensor[i * BATCH_TILE + ix_full, iy_full], + dst=a_tile[ix_full, iy_full], + mask=(i * BATCH_TILE + ix_full < num_rows) + ) + + # Normalize by RMS + out_tile = nl.multiply(a_tile, rms_reciprocal) + + # Apply weight + g_bcast = g_tile.broadcast_to((BATCH_TILE, hidden_dim)) + out_tile = nl.multiply(out_tile, g_bcast, mask=(i * BATCH_TILE + ix_full < num_rows)) + + # Store result with mask using nisa.dma_copy + nisa.dma_copy( + src=out_tile[ix_full, iy_full], + dst=out_tensor[i * BATCH_TILE + ix_full, iy_full], + mask=(i * BATCH_TILE + ix_full < num_rows) + ) + + return out_tensor diff --git a/contributed/batch_invariance/test_batch_invariance.py b/contributed/batch_invariance/test_batch_invariance.py new file mode 100644 index 0000000..9223622 --- /dev/null +++ b/contributed/batch_invariance/test_batch_invariance.py @@ -0,0 +1,446 @@ +""" +Simple Batch Invariance Test +""" + +import torch +import time +import torch_neuronx +import numpy as np +from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel_lang, nki_rmsnorm_kernel_isa +from kernels.matmul_batch_invariant import nki_matmul_kernel_isa, nki_matmul_kernel_lang + +# Prove that the kernels match pytorch and are functionally correct +def test_matmul_kernel_correctness(): + """ + Verify NKI matmul kernels produce correct results vs PyTorch. + + Validates mathematical correctness before analyzing batch invariance effects. 
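+
+    Note: the Lang kernel consumes `a` as [M, K]; the ISA kernel expects its
+    first operand transposed as [K, M].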
+ """ + print("Testing MatMul Correctness...") + device = 'xla' + + # Test dimensions + M, K, N = 256, 512, 512 + + print(f" Matrix dimensions: [{M}, {K}] @ [{K}, {N}] = [{M}, {N}]") + print() + + # Create test data + np.random.seed(42) + a_np = np.random.randn(M, K).astype(np.float32) + b_np = np.random.randn(K, N).astype(np.float32) + + # PyTorch reference (CPU) + a_torch = torch.tensor(a_np, dtype=torch.float32) + b_torch = torch.tensor(b_np, dtype=torch.float32) + + print(" Computing PyTorch reference (CPU)...") + start = time.time() + ref_output = torch.matmul(a_torch, b_torch) + ref_time = time.time() - start + print(f" Time: {ref_time:.6f}s") + print(f" Output shape: {ref_output.shape}") + print(f" First values: {ref_output[0, :5].numpy()}") + print() + + # Test Lang kernel - expects [M, K] @ [K, N] + print(" Testing Lang kernel (nl.matmul)...") + a_xla = torch.tensor(a_np, dtype=torch.float32, device=device) # [M, K] + b_xla = torch.tensor(b_np, dtype=torch.float32, device=device) # [K, N] + + start = time.time() + output_lang = nki_matmul_kernel_lang(a_xla, b_xla, batch_invariant=True) + lang_time = time.time() - start + + output_lang_cpu = output_lang.cpu() + print(f" Time: {lang_time:.6f}s") + print(f" Output shape: {output_lang_cpu.shape}") + print(f" First values: {output_lang_cpu[0, :5].numpy()}") + + lang_match = torch.allclose(ref_output, output_lang_cpu, atol=1e-4, rtol=1e-2) + max_diff_lang = torch.max(torch.abs(ref_output - output_lang_cpu)).item() + + if lang_match: + print(f" ✓ Matches PyTorch reference") + else: + print(f" ✗ Differs from PyTorch reference") + print(f" Max difference: {max_diff_lang:.6f}") + print() + + # Test ISA kernel - expects [K, M] @ [K, N] + print(" Testing ISA kernel (nisa.nc_matmul)...") + a_xla_t = torch.tensor(a_np.T, dtype=torch.float32, device=device) # [K, M] - transposed! + b_xla = torch.tensor(b_np, dtype=torch.float32, device=device) # [K, N] + + start = time.time() + output_isa = nki_matmul_kernel_isa(a_xla_t, b_xla, batch_invariant=True) + isa_time = time.time() - start + + output_isa_cpu = output_isa.cpu() + print(f" Time: {isa_time:.6f}s") + print(f" Output shape: {output_isa_cpu.shape}") + print(f" First values: {output_isa_cpu[0, :5].numpy()}") + + isa_match = torch.allclose(ref_output, output_isa_cpu, atol=1e-4, rtol=1e-2) + max_diff_isa = torch.max(torch.abs(ref_output - output_isa_cpu)).item() + + if isa_match: + print(f" ✓ Matches PyTorch reference") + else: + print(f" ✗ Differs from PyTorch reference") + print(f" Max difference: {max_diff_isa:.6f}") + print() + + # Summary + print("=" * 80) + if lang_match and isa_match: + print("✓ Both kernels produce correct results") + else: + print("✗ One or more kernels differ from PyTorch reference") + if not lang_match: + print(f" Lang kernel max error: {max_diff_lang:.6f}") + if not isa_match: + print(f" ISA kernel max error: {max_diff_isa:.6f}") + + assert lang_match, f"Lang kernel doesn't match PyTorch (max diff: {max_diff_lang})" + assert isa_match, f"ISA kernel doesn't match PyTorch (max diff: {max_diff_isa})" + +def test_matmul_isa(): + """ + ISA kernel K-tiling batch variance with quantization erasure. 
+
+    Expected: bfloat16 error = 0.0 despite float32 showing differences
+    Reason: nisa.nc_matmul produces float32 errors below bfloat16 threshold (~0.008)
+    Result: Demonstrates hardware-level numerical stability
+
+    Returns:
+        dict: Test results with float32 and bfloat16 errors
+    """
+    print("Testing MatMul batch variance (ISA kernel)...")
+    device = 'xla'
+
+    K, N = 512, 512
+    M_TILE = 128
+    large_batch = 256  # 2x M_TILE
+    small_batch = 128  # 1x M_TILE
+
+    print(f" K={K} -> deterministic=True: K_TILE=128, deterministic=False: K_TILE=64")
+    print()
+
+    # Create data ONCE in float32 - ISA kernel needs [K, M] layout!
+    print(" Creating data in float32...")
+    a_large_f32 = torch.linspace(-1, 1, large_batch * K, device=device).reshape(K, large_batch).to(torch.float32)
+    b_f32 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.float32)
+
+    # Test with float32 FIRST
+    print(" Testing with float32:")
+    a_small_f32 = a_large_f32[:, :small_batch]  # [K, 128]
+
+    result_small_f32 = nki_matmul_kernel_isa(a_small_f32, b_f32, deterministic=True)
+    result_large_f32 = nki_matmul_kernel_isa(a_large_f32, b_f32, deterministic=False)
+
+    diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item()
+    print(f" Max difference: {diff_f32:.6f}")
+    print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}")
+    print()
+
+    # Cast to bfloat16
+    print(" Testing with bfloat16:")
+    a_large_bf16 = a_large_f32.to(torch.bfloat16)
+    b_bf16 = b_f32.to(torch.bfloat16)
+    a_small_bf16 = a_large_bf16[:, :small_batch]
+
+    result_small_bf16 = nki_matmul_kernel_isa(a_small_bf16, b_bf16, deterministic=True)
+    result_large_bf16 = nki_matmul_kernel_isa(a_large_bf16, b_bf16, deterministic=False)
+
+    diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item()
+    print(f" Max difference: {diff_bf16:.6f}")
+    print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}")
+    print()
+
+    if diff_f32 > 0:
+        ratio = diff_bf16 / diff_f32
+        print(f" Precision impact: bfloat16 error is {ratio:.2f}x {'larger' if diff_bf16 > diff_f32 else 'smaller'} than float32")
+        if diff_bf16 == 0.0:
+            print(f" Note: Float32 error ({diff_f32:.6f}) is below bfloat16 quantization threshold (~0.008)")
+            print(f" Quantization erases the difference rather than amplifying it")
+    else:
+        ratio = 0.0
+        print(f" Precision impact: N/A (no float32 difference detected)")
+
+    return {
+        "kernel": "ISA (nisa.nc_matmul)",
+        "float32_error": diff_f32,
+        "bfloat16_error": diff_bf16,
+        "amplification": ratio
+    }
+
+def test_matmul_lang():
+    """
+    Lang kernel K-tiling batch variance with precision amplification.
+
+    Expected: bfloat16 error ~170x larger than float32
+    Reason: nl.matmul produces float32 errors above bfloat16 threshold
+    Result: Demonstrates how reduced precision amplifies tiling strategy effects
+
+    Returns:
+        dict: Test results with float32 and bfloat16 errors
+    """
+    print("Testing MatMul batch variance (Lang kernel)...")
+    device = 'xla'
+
+    K, N = 512, 512
+    M_TILE = 128
+    large_batch = 256  # 2x M_TILE
+    small_batch = 128  # 1x M_TILE
+
+    print(f" K={K} -> deterministic=True: K_TILE=128, deterministic=False: K_TILE=64")
+    print()
+
+    # Create data ONCE in float32 - single source of truth
+    print(" Creating data in float32...")
+    a_large_f32 = torch.linspace(-1, 1, large_batch * K, device=device).reshape(large_batch, K).to(torch.float32)
+    b_f32 = torch.linspace(-1, 1, K * N, device=device).reshape(K, N).to(torch.float32)
+
+    # Test with float32 FIRST
+    print(" Testing with float32:")
+    # Test the SAME 128 rows in different batch contexts
+    a_small_f32 = a_large_f32[:small_batch, :]
+
+    # Process as small batch (128 rows)
+    result_small_f32 = nki_matmul_kernel_lang(a_small_f32, b_f32, deterministic=True)
+
+    # Process as part of large batch (256 rows)
+    result_large_f32 = nki_matmul_kernel_lang(a_large_f32, b_f32, deterministic=False)
+
+    # Compare the SAME rows
+    diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item()
+    print(f" Max difference between K_TILE strategies: {diff_f32:.6f}")
+    print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}")
+    print()
+
+    # Cast to bfloat16 from the SAME float32 source
+    print(" Testing with bfloat16:")
+    a_large_bf16 = a_large_f32.to(torch.bfloat16)
+    b_bf16 = b_f32.to(torch.bfloat16)
+
+    # Test the SAME 128 rows in different batch contexts
+    a_small_bf16 = a_large_bf16[:small_batch, :]
+
+    # Process as small batch (128 rows)
+    result_small_bf16 = nki_matmul_kernel_lang(a_small_bf16, b_bf16, deterministic=True)
+
+    # Process as part of large batch (256 rows)
+    result_large_bf16 = nki_matmul_kernel_lang(a_large_bf16, b_bf16, deterministic=False)
+
+    # Compare the SAME rows
+    diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item()
+    print(f" Max difference between K_TILE strategies: {diff_bf16:.6f}")
+    print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}")
+    print()
+
+    if diff_f32 > 0:
+        ratio = diff_bf16 / diff_f32
+        print(f" Precision impact: bfloat16 error is {ratio:.2f}x larger than float32")
+        print(f" This demonstrates how reduced precision amplifies tiling strategy effects")
+    else:
+        ratio = 0.0
+        print(f" Precision impact: N/A (no float32 difference detected)")
+
+    return {
+        "kernel": "Lang (nl.matmul)",
+        "float32_error": diff_f32,
+        "bfloat16_error": diff_bf16,
+        "amplification": ratio
+    }
+
+
+def test_rmsnorm_lang():
+    """
+    RMSNorm Lang kernel HIDDEN_TILE variance with precision effects.
+
+    Uses nl.load, nl.store, nl.sum for data movement and reduction.
+    Different HIDDEN_TILE sizes create different reduction orders.
+
+    Expected: Shows variance in both float32 and bfloat16
+
+    Returns:
+        dict: Test results with float32 and bfloat16 errors
+    """
+    print("Testing RMSNorm batch variance (Lang kernel)...")
+    device = 'xla'
+    hidden_dim = 512
+    large_batch = 128
+    small_batch = 32
+
+    print(f" hidden_dim={hidden_dim}")
+    print(f" deterministic=True: HIDDEN_TILE=128 (fixed)")
+    print(f" deterministic=False: HIDDEN_TILE chosen adaptively from hidden_dim")
+    print()
+
+    # Create data ONCE in float32
+    print(" Creating data in float32...")
+    a_large_f32 = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.float32)
+    g_f32 = torch.ones(hidden_dim, device=device, dtype=torch.float32)
+
+    # Test with float32 FIRST
+    print(" Testing with float32:")
+    a_small_f32 = a_large_f32[:small_batch, :]
+
+    result_small_f32 = nki_rmsnorm_kernel_lang(a_small_f32, g_f32, deterministic=True)
+    result_large_f32 = nki_rmsnorm_kernel_lang(a_large_f32, g_f32, deterministic=False)
+
+    diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item()
+    print(f" Max difference between HIDDEN_TILE strategies: {diff_f32:.6f}")
+    print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}")
+    print()
+
+    # Cast to bfloat16
+    print(" Testing with bfloat16:")
+    a_large_bf16 = a_large_f32.to(torch.bfloat16)
+    g_bf16 = g_f32.to(torch.bfloat16)
+    a_small_bf16 = a_large_bf16[:small_batch, :]
+
+    result_small_bf16 = nki_rmsnorm_kernel_lang(a_small_bf16, g_bf16, deterministic=True)
+    result_large_bf16 = nki_rmsnorm_kernel_lang(a_large_bf16, g_bf16, deterministic=False)
+
+    diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item()
+    print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}")
+    print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}")
+    print()
+
+    if diff_f32 > 0:
+        ratio = diff_bf16 / diff_f32
+        print(f" Precision impact: bfloat16 error is {ratio:.2f}x {'larger' if diff_bf16 > diff_f32 else 'smaller'} than float32")
+        print(f" Lang kernel shows variance due to different reduction chunking")
+    else:
+        ratio = 0.0
+        print(f" Precision impact: N/A (no float32 difference detected)")
+
+    return {
+        "kernel": "RMSNorm Lang (nl.sum)",
+        "float32_error": diff_f32,
+        "bfloat16_error": diff_bf16,
+        "amplification": ratio
+    }
+
+
+def test_rmsnorm_isa():
+    """
+    RMSNorm ISA kernel demonstrates batch INVARIANCE.
+
+    Uses nisa.dma_copy and nisa.tensor_reduce with skip_middle_end_transformations.
+    Despite different HIDDEN_TILE sizes, ISA produces identical results.
+
+    Expected: No variance in either float32 or bfloat16
+    Reason: ISA-level operations are deterministic regardless of tiling strategy
+
+    Returns:
+        dict: Test results with float32 and bfloat16 errors (should be 0.0)
+    """
+    print("Testing RMSNorm batch INVARIANCE (ISA kernel)...")
+    device = 'xla'
+    hidden_dim = 512
+    large_batch = 128
+    small_batch = 32
+
+    print(f" hidden_dim={hidden_dim}")
+    print(f" deterministic=True: HIDDEN_TILE=128 (fixed)")
+    print(f" deterministic=False: HIDDEN_TILE chosen adaptively from hidden_dim")
+    print(f" Note: ISA kernel uses @skip_middle_end_transformations")
+    print()
+
+    # Create data ONCE in float32
+    print(" Creating data in float32...")
+    a_large_f32 = torch.linspace(-1, 1, large_batch * hidden_dim, device=device).reshape(large_batch, hidden_dim).to(torch.float32)
+    g_f32 = torch.ones(hidden_dim, device=device, dtype=torch.float32)
+
+    # Test with float32 FIRST
+    print(" Testing with float32:")
+    a_small_f32 = a_large_f32[:small_batch, :]
+
+    result_small_f32 = nki_rmsnorm_kernel_isa(a_small_f32, g_f32, deterministic=True)
+    result_large_f32 = nki_rmsnorm_kernel_isa(a_large_f32, g_f32, deterministic=False)
+
+    diff_f32 = torch.max(torch.abs(result_small_f32 - result_large_f32[:small_batch])).item()
+    print(f" Max difference between HIDDEN_TILE strategies: {diff_f32:.6f}")
+    print(f" Results {'identical' if diff_f32 < 1e-6 else 'differ'}")
+    print()
+
+    # Cast to bfloat16
+    print(" Testing with bfloat16:")
+    a_large_bf16 = a_large_f32.to(torch.bfloat16)
+    g_bf16 = g_f32.to(torch.bfloat16)
+    a_small_bf16 = a_large_bf16[:small_batch, :]
+
+    result_small_bf16 = nki_rmsnorm_kernel_isa(a_small_bf16, g_bf16, deterministic=True)
+    result_large_bf16 = nki_rmsnorm_kernel_isa(a_large_bf16, g_bf16, deterministic=False)
+
+    diff_bf16 = torch.max(torch.abs(result_small_bf16 - result_large_bf16[:small_batch])).item()
+    print(f" Max difference between HIDDEN_TILE strategies: {diff_bf16:.6f}")
+    print(f" Results {'identical' if diff_bf16 < 1e-6 else 'differ'}")
+    print()
+
+    if diff_f32 == 0.0 and diff_bf16 == 0.0:
+        print(f" ✓ ISA kernel is BATCH INVARIANT!")
+        print(f" @skip_middle_end_transformations ensures deterministic reduction")
+        print(f" regardless of HIDDEN_TILE size")
+        ratio = 0.0
+    elif diff_f32 > 0:
+        ratio = diff_bf16 / diff_f32
+        print(f" Precision impact: bfloat16 error is {ratio:.2f}x {'larger' if diff_bf16 > diff_f32 else 'smaller'} than float32")
+    else:
+        ratio = 0.0
+        print(f" Precision impact: N/A")
+
+    return {
+        "kernel": "RMSNorm ISA (nisa.tensor_reduce)",
+        "float32_error": diff_f32,
+        "bfloat16_error": diff_bf16,
+        "amplification": ratio
+    }
+
+
+if __name__ == "__main__":
+    import pandas as pd
+
+    print("Batch Invariance Test")
+    print("=" * 80)
+
+    # Run correctness test
+    test_matmul_kernel_correctness()
+    print("=" * 80)
+
+    # Test Lang kernel
+    print("\nRunning Lang kernel test...")
+    lang_results = test_matmul_lang()
+
+    print("=" * 80)
+
+    # Test ISA kernel
+    print("\nRunning ISA kernel test...")
+    isa_results = test_matmul_isa()
+
+    print("=" * 80)
+
+    # Test RMSNorm Lang kernel
+    print("\nRunning RMSNorm Lang kernel test...")
+    rmsnorm_lang_results = test_rmsnorm_lang()
+
+    print("=" * 80)
+
+    # Test RMSNorm ISA kernel
+    print("\nRunning RMSNorm ISA kernel test...")
+    rmsnorm_isa_results = test_rmsnorm_isa()
+
+    print("\n" + "=" * 80)
+    print("SUMMARY")
+    print("=" * 80)
+
+    # Create results dataframe
+    print("\nBatch Variance Results:")
+    variance_df = 
pd.DataFrame([lang_results, isa_results, rmsnorm_lang_results, rmsnorm_isa_results]) + print(variance_df.to_string(index=False)) + print() diff --git a/contributed/batch_invariance/test_determinism.ipynb b/contributed/batch_invariance/test_determinism.ipynb new file mode 100644 index 0000000..b70c999 --- /dev/null +++ b/contributed/batch_invariance/test_determinism.ipynb @@ -0,0 +1,273 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ba410693", + "metadata": {}, + "outputs": [], + "source": [ + "from kernels.rmsnorm_batch_invariant import nki_rmsnorm_kernel_lang, nki_rmsnorm_kernel_isa\n", + "from kernels.matmul_batch_invariant import nki_matmul_kernel_isa, nki_matmul_kernel_lang\n", + "import torch\n", + "import torch_neuronx " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "17524879", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def test_determinism(kernel_fn, a, b, deterministic, iterations=1000):\n", + " \"\"\"Test kernel produces identical results across 1000 iterations.\"\"\"\n", + " ref = kernel_fn(a, b, deterministic=deterministic)\n", + " \n", + " for i in range(iterations):\n", + " result = kernel_fn(a, b, deterministic=deterministic)\n", + " max_diff = (result - ref).abs().max().item()\n", + " \n", + " if max_diff != 0:\n", + " print(f\" FAILED at iteration {i}: max_diff={max_diff}\")\n", + " return False\n", + " \n", + " print(f\" PASSED: {iterations} iterations identical\")\n", + " return True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3c0aaad", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing 1000 iterations...\n", + "\n", + "deterministic=True:\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 21:55:07.000869: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_11646591744998724192+fad94d7c.hlo_module.pb\n", + " PASSED: 10000 iterations identical\n", + "\n", + "============================================================\n", + "deterministic=True: PASS\n" + ] + } + ], + "source": [ + "device = 'xla'\n", + "K, M, N = 512, 256, 512\n", + "\n", + "A = torch.randn(K, M, device=device, dtype=torch.bfloat16)\n", + "B = torch.randn(K, N, device=device, dtype=torch.bfloat16)\n", + "\n", + "print(\"Testing 10000 iterations...\")\n", + "\n", + "print(\"\\ndeterministic=True:\")\n", + "pass_det = test_determinism(nki_matmul_kernel_isa, A, B, deterministic=True, iterations=10000)\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(f\"deterministic=True: {'PASS' if pass_det else 'FAIL'}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "62c20c1f", + "metadata": {}, + "outputs": [], + "source": [ + "def test_tiling_invariance(kernel_fn, is_isa=False, determinism=True, dtype=torch.bfloat16):\n", + " device = 'xla'\n", + " M, K, N = 512, 512, 512\n", + " \n", + " if is_isa:\n", + " # ISA expects [K, M] @ [K, N]\n", + " a = torch.linspace(-1, 1, K * M, device=device, dtype=dtype).reshape(K, M)\n", + " else:\n", + " # Lang expects [M, K] @ [K, N]\n", + " a = torch.linspace(-1, 1, M * K, device=device, dtype=dtype).reshape(M, K)\n", + " \n", + " b = torch.linspace(-1, 1, K * N, device=device, dtype=dtype).reshape(K, N)\n", + " \n", + " out_det = kernel_fn(a, b, deterministic=True) # K_TILE=128\n", + " out_adp = kernel_fn(a, b, deterministic=determinism) # K_TILE=64\n", + " \n", + " diff = (out_det - out_adp).abs().max().item()\n", + " \n", + 
" name = \"ISA\" if is_isa else \"Lang\"\n", + " print(f\"{name}: deterministic=True vs {determinism} → diff={diff:.6f}\")\n", + " print(f\" Tiling affects numerics: {'YES' if diff > 0 else 'NO'}\")\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "858001a6", + "metadata": {}, + "source": [ + "# Lang kernel deterministic vs non" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8e9bf743", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2026-Jan-30 21:50:02.0908 13220:13274 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):219 CCOM WARN NET/OFI Failed to initialize rdma protocol\n", + "2026-Jan-30 21:50:02.0911 13220:13274 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", + "2026-Jan-30 21:50:02.0913 13220:13274 [1] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed\n", + "2026-Jan-30 21:50:02.0916 13220:13274 [1] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 21:50:04.000403: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_11522224973351651600+fad94d7c.hlo_module.pb\n", + "Lang: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 21:50:05.000978: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_7687714875879817323+fad94d7c.hlo_module.pb\n", + "Lang: deterministic=True vs False → diff=0.007812\n", + " Tiling affects numerics: YES\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False)\n", + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False, determinism=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "612e5096", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Lang: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + ".Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "2026-01-30 21:50:10.000417: 13220 [INFO]: Compilation Successfully Completed for model.MODULE_6421119283783150616+fad94d7c.hlo_module.pb\n", + "Lang: deterministic=True vs False → diff=0.000046\n", + " Tiling affects numerics: YES\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False)\n", + "test_tiling_invariance(nki_matmul_kernel_lang, is_isa=False, determinism=False, dtype=torch.float32)" + ] + }, + { + "cell_type": "markdown", + "id": "8b375ee0", + "metadata": {}, + "source": [ + "# ISA kernel deterministic vs non" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ce21177c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2026-01-30 21:50:24.000003: 13220 [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_5313299922059221254+fad94d7c/model.neff\n", + "ISA: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + "2026-01-30 21:50:24.000047: 13220 [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_16718627453147721994+fad94d7c/model.neff\n", + "ISA: deterministic=True vs False → diff=0.000000\n", + " Tiling affects numerics: 
NO\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True)\n", + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True, determinism=False)" + ] + }, + { + "cell_type": "markdown", + "id": "790c7628", + "metadata": {}, + "source": [ + "# ISA kernel deterministic vs non with float32" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "134ebb44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ISA: deterministic=True vs True → diff=0.000000\n", + " Tiling affects numerics: NO\n", + "2026-01-30 21:50:27.000813: 13220 [INFO]: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.22.12471.0+b4a00d10/MODULE_11375411469173762114+fad94d7c/model.neff\n", + "ISA: deterministic=True vs False → diff=0.000061\n", + " Tiling affects numerics: YES\n" + ] + } + ], + "source": [ + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True)\n", + "test_tiling_invariance(nki_matmul_kernel_isa, is_isa=True, determinism=False, dtype=torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff6d3f27", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aws_neuronx_venv_pytorch_2_9", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}