diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 22636828f9..0b37627901 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -60,6 +60,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" # Disable autotuning to make unittests faster. In addition, disable TF32 path to fully align with the pytorch reference implementation's precision NVTE_DISABLE_TRITON_AUTOTUNING=1 NVIDIA_TF32_OVERRIDE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mhc.xml $TE_PATH/tests/pytorch/test_mhc.py || test_fail "test_mhc.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_DISABLE_TRITON_AUTOTUNING=1 NVIDIA_TF32_OVERRIDE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mhc_deterministic.xml $TE_PATH/tests/pytorch/test_mhc.py || test_fail "test_mhc.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py index 541ce9a8c2..0fa84551f9 100644 --- a/tests/pytorch/test_mhc.py +++ b/tests/pytorch/test_mhc.py @@ -14,13 +14,15 @@ mhc_fused_aggregate, mhc_fused_expand_combine, mhc_fused_projection, + mhc_generate_mix_and_aggregate, + is_deterministic_enforced, ) # Disable TF32 for matmul to ensure consistency between the fused and reference implementations torch.backends.cuda.matmul.allow_tf32 = False -def mhc_projection_ref(x, phi): +def mhc_projection_ref(x, phi, norm_weight): """ Reference operator for mHC's projection building operation. @@ -29,19 +31,20 @@ def mhc_projection_ref(x, phi): - phi_pre: (n, nC) - phi_post: (n, nC) - phi_res: (n^2, nC) + norm_weight: (nC,) or None, if not None, apply element-wise multiplication to phi before projection n: number of Hyper Connection streams C: hidden dimension per stream """ - x_dtype = x.dtype - x = x.to(torch.float32) - phi = phi.to(torch.float32) - - Hs = x @ phi.T # (M, 2n + n^2) - x_fp32 = x.to(torch.float32) # Use fp32 for better numerical stability in variance calculation + x_fp32 = x.to(torch.float32) ms = (x_fp32 * x_fp32).mean(dim=1) - return Hs.to(x_dtype), ms + phi_fp32 = phi.to(torch.float32) + if norm_weight is not None: + phi_fp32 = phi_fp32 * norm_weight.to(torch.float32)[None, :] + Hs = x_fp32 @ phi_fp32.T # (M, 2n + n^2) + + return Hs, ms def mhc_scale_ref(H, alpha, beta, ms, n): @@ -139,9 +142,9 @@ def mhc_aggregate_ref(x, H_pre, n): s, b, C, n = x.shape H_pre = H_pre.view(s, b, n, 1) - out = (x @ H_pre).view(s, b, C) + out = (x.to(H_pre.dtype) @ H_pre).view(s, b, C) - return out + return out.to(x.dtype) def mhc_expand_combine_ref(f, bias, H_post, x, H_res, n): @@ -267,25 +270,42 @@ def get_tols(dtype): @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) -def test_mhc_projection(cfg: MHCConfig, dtype): +@pytest.mark.parametrize( + "dtypes", + [ + (torch.float32, torch.float32), + (torch.bfloat16, torch.bfloat16), + (torch.bfloat16, torch.float32), + ], + ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"], +) +@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"]) +def test_mhc_projection(cfg: MHCConfig, dtypes, has_norm_weight): reset_rng_states() s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n nC = n * C N = 2 * n + n * n - tols = get_tols(dtype) + x_dtype = dtypes[0] + phi_dtype = dtypes[1] + tols = get_tols(x_dtype) use_tf32 = False - x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) - phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") - + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype) + phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda") x_ref = x.detach().clone().requires_grad_(True) phi_ref = phi.detach().clone().requires_grad_(True) - ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref) - fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, use_tf32) + if has_norm_weight: + norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype) + norm_weight_ref = norm_weight.detach().clone().requires_grad_(True) + else: + norm_weight = None + norm_weight_ref = None + + ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref) + fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, norm_weight, use_tf32) fused_out_Hs = fused_out_Hs_padded[:, :N] torch.testing.assert_close(fused_out_Hs, ref_out_Hs, **tols) @@ -295,10 +315,12 @@ def test_mhc_projection(cfg: MHCConfig, dtype): torch.testing.assert_close(x.grad, x_ref.grad, **tols) torch.testing.assert_close(phi.grad, phi_ref.grad, **tols) + if has_norm_weight: + torch.testing.assert_close(norm_weight.grad, norm_weight_ref.grad, **tols) @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) def test_mhc_scale(cfg: MHCConfig, dtype): reset_rng_states() @@ -329,28 +351,39 @@ def test_mhc_scale(cfg: MHCConfig, dtype): torch.cat([fused_out[i] for i in range(3)], dim=-1).sum().backward() torch.testing.assert_close(H_padded.grad[:, :N], H_ref.grad, **tols) + torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols) torch.testing.assert_close(beta.grad, beta_ref.grad, **tols) - torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) -def test_mhc_combined(cfg: MHCConfig, dtype): +@pytest.mark.parametrize( + "dtypes", + [ + (torch.float32, torch.float32), + (torch.bfloat16, torch.bfloat16), + (torch.bfloat16, torch.float32), + ], + ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"], +) +@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"]) +def test_mhc_rmsnorm(cfg: MHCConfig, dtypes, has_norm_weight): + # Verify if the fused kernel is equivalent to applying RMSNorm in the normal order reset_rng_states() s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n N = 2 * n + n * n nC = n * C - tols = get_tols(dtype) + x_dtype = dtypes[0] + phi_dtype = dtypes[1] + tols = get_tols(x_dtype) use_tf32 = False - x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) - phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") - - alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) - beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype) + phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda") + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=phi_dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=phi_dtype) x_ref = x.detach().clone().requires_grad_(True) phi_ref = phi.detach().clone().requires_grad_(True) @@ -358,8 +391,15 @@ def test_mhc_combined(cfg: MHCConfig, dtype): alpha_ref = alpha.detach().clone().requires_grad_(True) beta_ref = beta.detach().clone().requires_grad_(True) - ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref) - fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, use_tf32) + if has_norm_weight: + norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype) + norm_weight_ref = norm_weight.detach().clone().requires_grad_(True) + else: + norm_weight = None + norm_weight_ref = None + + ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref) + fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, norm_weight, use_tf32) ref_H_pre, ref_H_post, ref_H_res = mhc_scale_ref( ref_out_H[:, :N], alpha_ref, beta_ref, ref_out_r, n @@ -368,17 +408,19 @@ def test_mhc_combined(cfg: MHCConfig, dtype): fused_out_H_padded, alpha, beta, fused_out_r, n ) - def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref): - dtype = x_ref.dtype - x_ref = x_ref.to(torch.float32) - phi_ref = phi_ref.to(torch.float32) - alpha_ref = alpha_ref.to(torch.float32) - beta_ref = beta_ref.to(torch.float32) - + def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref): # Check if after spliting RMSNorm to two steps in projection and scaling, - # theresult is close to applying RMSNorm in the correct order - x_rmsnorm = F.rms_norm(x_ref, normalized_shape=(nC,)) - H = x_rmsnorm @ phi_ref.T + # the result is close to applying RMSNorm in the correct order. + # Run RMSNorm in fp32 so the bf16 case has the same precision pattern as the + # kernel/ref (F.rms_norm on bf16 input would round x_rmsnorm back to bf16). + eps = torch.finfo(torch.float32).eps + norm_weight_fp32 = ( + norm_weight_ref.to(torch.float32) if norm_weight_ref is not None else None + ) + x_rmsnorm = F.rms_norm( + x_ref.to(torch.float32), normalized_shape=(nC,), weight=norm_weight_fp32, eps=eps + ) + H = x_rmsnorm @ phi_ref.T.to(torch.float32) H_pre = H[:, :n] H_post = H[:, n : 2 * n] H_res = H[:, 2 * n :] @@ -391,21 +433,88 @@ def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref): out_post = 2 * out_post.sigmoid() out_res = out_res - return out_pre.to(dtype), out_post.to(dtype), out_res.to(dtype) + return out_pre, out_post, out_res # Return in FP32 to match the kernel's behavior combined_H_pre, combined_H_post, combined_H_res = mhc_combined( - x_ref, phi_ref, alpha_ref, beta_ref + x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref ) torch.testing.assert_close(combined_H_pre, ref_H_pre, **tols) torch.testing.assert_close(combined_H_post, ref_H_post, **tols) torch.testing.assert_close(combined_H_res, ref_H_res, **tols) + torch.testing.assert_close(ref_H_pre, fused_H_pre, **tols) + torch.testing.assert_close(ref_H_post, fused_H_post, **tols) + torch.testing.assert_close(ref_H_res, fused_H_res, **tols) + torch.testing.assert_close(combined_H_pre, fused_H_pre, **tols) torch.testing.assert_close(combined_H_post, fused_H_post, **tols) torch.testing.assert_close(combined_H_res, fused_H_res, **tols) +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +def test_mhc_fuse_grad_acc(cfg: MHCConfig, dtype): + # Skip bf16 tests since in the unfused path the we accumulate 3 bf16 gradients, whereas in the fused path + # we accumulate 3 fp32 gradients and then cast to bf16 in the end, which causes two paths to have different precision patterns + + if not is_deterministic_enforced(): + pytest.skip( + "This test needs to be tested under deterministic mode to avoid discrepancies from" + " running two non-deterministic implementations twice" + ) + + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + N = 2 * n + n * n + nC = n * C + + # Since we tested the exactly same path twice with only gradient accumulation logic different, we can use tighter tolerances here + tols = {"atol": 1e-6, "rtol": 1e-6} + use_tf32 = False + + x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype) + phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") + + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + x_ref = x.detach().clone().requires_grad_(True) + phi_ref = phi.detach().clone().requires_grad_(True) + + alpha_ref = alpha.detach().clone().requires_grad_(True) + beta_ref = beta.detach().clone().requires_grad_(True) + + def end_to_end(x, phi, alpha, beta, fused_grad_x_acc): + aggregated, H_post, H_res = mhc_generate_mix_and_aggregate( + x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc + ) + H_res = mhc_fused_sinkhorn(H_res.view(s, b, n, n), n).view(s * b, n * n) + expanded_combined = mhc_fused_expand_combine( + aggregated, + None, + H_post, + x, + H_res, + False, + fused_grad_x_acc, + ) + + return expanded_combined + + expanded_combined_fuse_grad = end_to_end(x_ref, phi_ref, alpha_ref, beta_ref, False) + expanded_combined_no_fuse_grad = end_to_end(x, phi, alpha, beta, True) + + grad_output = torch.randn_like(expanded_combined_fuse_grad) + expanded_combined_fuse_grad.backward(grad_output) + expanded_combined_no_fuse_grad.backward(grad_output) + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + torch.testing.assert_close(phi.grad, phi_ref.grad, **tols) + torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols) + torch.testing.assert_close(beta.grad, beta_ref.grad, **tols) + + @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) @pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"]) @@ -482,7 +591,7 @@ def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias): H_res_ref = H_res.detach().clone().requires_grad_(True) ref_out = mhc_expand_combine_ref(f_ref, bias_ref, H_post_ref, x_ref, H_res_ref, n) - fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n, False) + fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, False, False) torch.testing.assert_close(fused_out, ref_out, **tols) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 965bb437ff..a36748409b 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -8,10 +8,27 @@ import itertools import os +import torch import triton import triton.language as tl +MAX_GRID_DIM_Y = 65535 # Maximum grid dimension in Y direction for current CUDA architectures + + +def align_to(x, alignment): + return ((x + alignment - 1) // alignment) * alignment + + +def get_device_sms(): + """ + Get the number of SMs of the current device. This is used to determine the grid size for launching Triton kernels. + """ + device_id = torch.cuda.current_device() + device_props = torch.cuda.get_device_properties(device_id) + sm_count = device_props.multi_processor_count + return sm_count + def projection_config_fwd(): block_m = [64, 128] @@ -24,39 +41,65 @@ def projection_config_fwd(): for m, bk, sk, w, s in itertools.product(block_m, block_k, step_k, warps, stages): configs.append( triton.Config( - {"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk, "STEP_SIZE_K": sk}, + {"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk, "STEP_SIZE_K": sk, "USE_SPLIT_K": True}, num_warps=w, num_stages=s, ) ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] return configs -def projection_config_bwd(): - block_m = [32, 128] - block_k = [128] - warps = [2] - stages = [2, 3, 4] - - configs = [] - for m, bk, w, s in itertools.product(block_m, block_k, warps, stages): - configs.append( - triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s) - ) +def projection_prune_fwd(configs, named_args, **kwargs): + DETERMINISTIC = named_args.get("DETERMINISTIC", kwargs.get("DETERMINISTIC", None)) + M = named_args.get("M", kwargs.get("M", None)) + K = named_args.get("K", kwargs.get("K", None)) + + block_m = [8, 16, 32, 64] + block_k = align_to(K, 32) + + # Use Split-K only if determinism is not enforced and M is not large enough to effectively parallelize + # sms * 4 is a empirical threshold I found via experiments on B200 for non-split-K starts to be better + if not DETERMINISTIC and triton.cdiv(M, block_m[0]) < get_device_sms() * 4: + pruned_configs = configs + else: + step_k = [32, 64, 128] + warps = [1, 2, 4] + stages = [2, 3, 4] + + pruned_configs = [] + for bm, sk, w, s in itertools.product(block_m, step_k, warps, stages): + pruned_configs.append( + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_K": block_k, + "STEP_SIZE_K": sk, + "USE_SPLIT_K": False, + }, + num_warps=w, + num_stages=s, + ) + ) + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] - return configs + pruned_configs = pruned_configs[:1] + return pruned_configs -@triton.autotune(configs=projection_config_fwd(), key=["M", "K"], reset_to_zero=["h_ptr", "ms_ptr"]) +@triton.autotune( + configs=projection_config_fwd(), + key=["M", "K", "DETERMINISTIC"], + reset_to_zero=["h_ptr", "ms_ptr"], + prune_configs_by={"early_config_prune": projection_prune_fwd}, +) @triton.jit def _mhc_projection_fwd_fused( x_ptr, # (M, K) phi_ptr, # (N, K) h_ptr, # (M, 32) ms_ptr, # (M,) + norm_weight_ptr, # (K,) M, N, K, @@ -67,12 +110,16 @@ def _mhc_projection_fwd_fused( stride_hm: tl.constexpr, stride_hn: tl.constexpr, stride_ms: tl.constexpr, + stride_norm_weight: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, STEP_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, precision: tl.constexpr, + HAS_NORM_WEIGHT: tl.constexpr, + DETERMINISTIC: tl.constexpr, # pylint: disable=unused-argument # If user wants to enforce deterministic, which is used to prune configs + USE_SPLIT_K: tl.constexpr, # If we actually use split-K, which is determined by both DETERMINISTIC flag and input size ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -86,8 +133,9 @@ def _mhc_projection_fwd_fused( tl.assume(stride_hm == 32) tl.assume(stride_hn == 1) tl.assume(stride_ms == 1) + tl.assume(stride_norm_weight == 1) - tl.assume(BLOCK_SIZE_M % 32 == 0) + tl.assume(BLOCK_SIZE_M % 8 == 0) tl.assume(BLOCK_SIZE_K % 32 == 0) tl.assume(BLOCK_SIZE_N == 32) @@ -113,31 +161,72 @@ def _mhc_projection_fwd_fused( other=0.0, cache_modifier=".ca", ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) - ms_acc += tl.sum(x * x, axis=1) + + ms_acc += tl.sum(x.to(tl.float32) * x.to(tl.float32), axis=1) + + # In RMSNorm, mean square should be the mean squrare of the original x, so we need to first accumulate the sum of squares of x + # before we let x absore norm weight and pass x with norm weight's affine transformation applied to w to do the dot product + # to generate H. This is the correct way to fuse H = RMSNorm(x) @ phi.T. + if HAS_NORM_WEIGHT: + norm_weight_ptrs = norm_weight_ptr + k_offs * stride_norm_weight + norm_weight = tl.load( + norm_weight_ptrs, mask=mask_k, other=1.0, cache_modifier=".ca" + ) # (BLOCK_SIZE_K,) + phi = phi.to(tl.float32) * norm_weight.to(tl.float32) h_acc = tl.dot( - x, tl.trans(phi, (1, 0)), h_acc, input_precision=precision, out_dtype=tl.float32 + x.to(phi.dtype), + tl.trans(phi, (1, 0)), + h_acc, + input_precision=precision, + out_dtype=tl.float32, ) h_ptrs = h_ptr + offs_m[:, None] * stride_hm + offs_n_full[None, :] * stride_hn - tl.atomic_add(h_ptrs, h_acc, mask=mask_m[:, None], sem="relaxed") + if USE_SPLIT_K: + tl.atomic_add(h_ptrs, h_acc, mask=mask_m[:, None], sem="relaxed") + else: + tl.store(h_ptrs, h_acc, mask=mask_m[:, None]) offs_ms = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) masks_ms = offs_ms < M offs_ms %= M ms_ptrs = ms_ptr + offs_ms * stride_ms ms = ms_acc / tl.cast(K, tl.float32) - tl.atomic_add(ms_ptrs, ms, mask=masks_ms, sem="relaxed") + if USE_SPLIT_K: + tl.atomic_add(ms_ptrs, ms, mask=masks_ms, sem="relaxed") + else: + tl.store(ms_ptrs, ms, mask=masks_ms) + + +def projection_config_bwd_dx(): + block_m = [32, 128] + block_k = [128] + warps = [2] + stages = [2, 3, 4] + + configs = [] + for m, bk, w, s in itertools.product(block_m, block_k, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs @triton.autotune( - configs=projection_config_bwd(), + configs=projection_config_bwd_dx(), key=["M", "K"], + # When FUSE_GRAD_X_ACC=True the kernel does a read-modify-write on grad_x_ptr; without + # restore_value the autotune timing trials accumulate onto the buffer and corrupt it. + restore_value=["grad_x_ptr"], ) @triton.jit -def _mhc_projection_bwd_fused( +def _mhc_projection_bwd_fused_dx( x_ptr, grad_x_ptr, # (M, K) phi_ptr, # (N, K) + norm_weight_ptr, # (K,) grad_h_ptr, # (M, N) grad_ms_ptr, # (M,) M, @@ -149,6 +238,7 @@ def _mhc_projection_bwd_fused( stride_grad_xk: tl.constexpr, stride_phin, stride_phik: tl.constexpr, + stride_norm_weight: tl.constexpr, stride_grad_phin, stride_grad_phik: tl.constexpr, stride_grad_hm: tl.constexpr, @@ -159,6 +249,8 @@ def _mhc_projection_bwd_fused( BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, precision: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, + HAS_NORM_WEIGHT: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -174,6 +266,7 @@ def _mhc_projection_bwd_fused( tl.assume(stride_grad_phin == K) tl.assume(stride_grad_phik == 1) tl.assume(stride_grad_ms == 1) + tl.assume(stride_norm_weight == 1) tl.assume(BLOCK_SIZE_M % 32 == 0) tl.assume(BLOCK_SIZE_K % 32 == 0) @@ -204,27 +297,216 @@ def _mhc_projection_bwd_fused( phi = tl.load( phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + + if HAS_NORM_WEIGHT: + norm_weight_ptrs = norm_weight_ptr + offs_k * stride_norm_weight + norm_weight = tl.load(norm_weight_ptrs, mask=mask_k, other=0.0, cache_modifier=".ca").to( + phi.dtype + ) # (BLOCK_SIZE_K,) + phi = phi.to(tl.float32) * norm_weight.to(tl.float32)[None, :] + grad_ms = tl.load( grad_ms_ptrs, mask=offs_ms < M, other=0.0, cache_modifier=".ca" ) # (BLOCK_SIZE_M,) grad_x = x * (grad_ms * 2 / tl.cast(K, tl.float32))[:, None] grad_x = tl.dot( - grad_h, phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 + grad_h.to(phi.dtype), phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_k[None, :] * stride_grad_xk grad_x = grad_x.to(x.dtype) + if FUSE_GRAD_X_ACC: # If fused gradient accumulation is enabled, the buffer is always fp32 + grad_x_acc = tl.load(grad_x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + grad_x = grad_x.to(tl.float32) + grad_x_acc tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_k[None, :]) +def projection_config_bwd_dphi(): + block_m = [512, 1024, 2048] + step_m = [32] + block_k = [128, 256] + warps = [2] + stages = [2, 3, 4] + + configs = [] + for bm, sm, bk, w, s in itertools.product(block_m, step_m, block_k, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": bm, "STEP_SIZE_M": sm, "BLOCK_SIZE_K": bk, "USE_SPLIT_M": True}, + num_warps=w, + num_stages=s, + ) + ) + return configs + + +def projection_prune_bwd_dphi(configs, named_args, **kwargs): + DETERMINISTIC = named_args.get("DETERMINISTIC", kwargs.get("DETERMINISTIC", None)) + M = named_args.get("M", kwargs.get("M", None)) + K = named_args.get("K", kwargs.get("K", None)) + + block_k = [128] + block_m = align_to(M, 128) + + # Use split-M only if determinism is not enforced and K is large enough to effectively parallelize + if not DETERMINISTIC and triton.cdiv(K, block_k[0]) < get_device_sms() * 4: + pruned_configs = configs + else: + step_m = [32] + warps = [4] + stages = [6, 7, 8] + + pruned_configs = [] + for bk, sm, w, s in itertools.product(block_k, step_m, warps, stages): + pruned_configs.append( + triton.Config( + { + "BLOCK_SIZE_M": block_m, + "STEP_SIZE_M": sm, + "BLOCK_SIZE_K": bk, + "USE_SPLIT_M": False, + }, + num_warps=w, + num_stages=s, + ) + ) + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs + + +@triton.autotune( + configs=projection_config_bwd_dphi(), + key=["M", "K", "DETERMINISTIC"], + reset_to_zero=["grad_phi_ptr", "grad_norm_weight_ptr"], + prune_configs_by={"early_config_prune": projection_prune_bwd_dphi}, +) +@triton.jit +def _mhc_projection_bwd_fused_dphi( + x_ptr, # (M, K) + grad_H_ptr, # (M, 32) + phi_ptr, # (N, K), N=24 in our case since n = 4 + norm_weight_ptr, # (K,) + grad_phi_ptr, # (N, K), N=24 in our case since n = 4 + grad_norm_weight_ptr, # (K,) + M, + N, + K, + stride_xm, + stride_xk: tl.constexpr, + stride_grad_Hm: tl.constexpr, + stride_grad_Hn: tl.constexpr, + stride_phin, + stride_phik: tl.constexpr, + stride_norm_weight: tl.constexpr, + stride_grad_phin, + stride_grad_phik: tl.constexpr, + stride_grad_norm_weight: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + STEP_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + precision: tl.constexpr, + DETERMINISTIC: tl.constexpr, # pylint: disable=unused-argument # If user wants to enforce deterministic, which is used to prune configs + USE_SPLIT_M: tl.constexpr, # If we actually use split-M, which is determined by both DETERMINISTIC flag and input size +): + pid_k = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + + tl.assume(pid_k >= 0) + tl.assume(stride_xm > 0) + tl.assume(stride_xk == 1) + tl.assume(stride_grad_Hm == 32) + tl.assume(stride_grad_Hn == 1) + tl.assume(stride_phin == K) + tl.assume(stride_phik == 1) + tl.assume(stride_grad_phin == K) + tl.assume(stride_grad_phin == stride_phin) + tl.assume(stride_grad_phik == 1) + tl.assume(stride_grad_norm_weight == 1) + tl.assume(stride_norm_weight == 1) + + tl.assume(BLOCK_SIZE_M % 128 == 0) + tl.assume(BLOCK_SIZE_K % 64 == 0) + tl.assume(BLOCK_SIZE_N == 32) + tl.assume(STEP_SIZE_M % 32 == 0) + + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask_k = offs_k < K + offs_n_full = tl.arange(0, BLOCK_SIZE_N) + mask_n = offs_n_full < N + + grad_psi_acc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + m_start = pid_m * BLOCK_SIZE_M + m_end = tl.minimum(m_start + BLOCK_SIZE_M, M) + for m_idx in range(0, tl.cdiv(m_end - m_start, STEP_SIZE_M)): + offs_m = m_start + m_idx * STEP_SIZE_M + tl.arange(0, STEP_SIZE_M) + mask_m = offs_m < M + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (STEP_SIZE_M, BLOCK_SIZE_K) + grad_H_ptrs = ( + grad_H_ptr + offs_m[:, None] * stride_grad_Hm + offs_n_full[None, :] * stride_grad_Hn + ) + grad_H = tl.load( + grad_H_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0 + ) # (STEP_SIZE_M, BLOCK_SIZE_N) + + grad_psi_acc = tl.dot( + tl.trans(grad_H, (1, 0)), + x.to(grad_H.dtype), + acc=grad_psi_acc, + out_dtype=tl.float32, + input_precision=precision, + ) + + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + offs_k[None, :] * stride_phik + phi = tl.load( + phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + norm_weight_ptrs = norm_weight_ptr + offs_k * stride_norm_weight + norm_weight = tl.load( + norm_weight_ptrs, mask=mask_k, other=0.0, cache_modifier=".cg" + ) # (BLOCK_SIZE_K,) + phi = phi.to(tl.float32) + norm_weight = norm_weight.to(tl.float32) + + # Keep grad_psi in SRAM and get grad_phi & grad_norm_weight + grad_phi = grad_psi_acc * norm_weight[None, :].to(grad_psi_acc.dtype) # (32, BLOCK_SIZE_K) + grad_norm_weight = tl.sum(grad_psi_acc * phi.to(grad_psi_acc.dtype), axis=0) # (BLOCK_SIZE_K,) + + grad_phi_ptrs = ( + grad_phi_ptr + offs_n_full[:, None] * stride_grad_phin + offs_k[None, :] * stride_grad_phik + ) + grad_norm_weight_ptrs = grad_norm_weight_ptr + offs_k * stride_grad_norm_weight + + if USE_SPLIT_M: + tl.atomic_add( + grad_phi_ptrs, + grad_phi, + mask=(offs_n_full[:, None] < N) & mask_k[None, :], + sem="relaxed", + ) + tl.atomic_add(grad_norm_weight_ptrs, grad_norm_weight, mask=mask_k, sem="relaxed") + else: + tl.store( + grad_phi_ptrs, grad_phi.to(phi.dtype), mask=(offs_n_full[:, None] < N) & mask_k[None, :] + ) + tl.store(grad_norm_weight_ptrs, grad_norm_weight.to(norm_weight.dtype), mask=mask_k) + + def scale_config(): - block_m = [128] warps = [4] stages = [1, 2, 4] configs = [] - for m, w, s in itertools.product(block_m, warps, stages): - configs.append(triton.Config({"BLOCK_SIZE_M": m}, num_warps=w, num_stages=s)) + for w, s in itertools.product(warps, stages): + configs.append(triton.Config({}, num_warps=w, num_stages=s)) if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": configs = configs[:1] @@ -324,23 +606,25 @@ def _mhc_scale_fwd_fused( def _mhc_scale_bwd_fused( grad_out_ptr, out_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension - grad_h_ptr, - h_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension + grad_H_ptr, + H_ptr, # (M, 2n + n^2), which is padded to (M, 32) in the last dimension grad_a_ptr, a_ptr, # (3,) grad_b_ptr, # (2n + n^2,) grad_ms_ptr, ms_ptr, # (M,) + ws_grad_a_ptr, # Temporary workspace for a with shape (NUM_SMS, 3), or None if DETERMINISTIC is False + ws_grad_b_ptr, # Temporary workspace for b with shape (NUM_SMS, 32), or None if DETERMINISTIC is False M, - n, + n: tl.constexpr, stride_grad_out_m, stride_grad_out_n, stride_out_m, stride_out_n, - stride_grad_hm, - stride_grad_hn, - stride_hm, - stride_hn, + stride_grad_Hm, + stride_grad_Hn, + stride_Hm, + stride_Hn, stride_grad_a, stride_a, stride_grad_b, @@ -349,6 +633,7 @@ def _mhc_scale_bwd_fused( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, eps: tl.constexpr, + DETERMINISTIC: tl.constexpr, ): pid = tl.program_id(0) @@ -358,10 +643,10 @@ def _mhc_scale_bwd_fused( tl.assume(stride_grad_out_n == 1) tl.assume(stride_out_m == 32) tl.assume(stride_out_n == 1) - tl.assume(stride_grad_hm == 32) - tl.assume(stride_grad_hn == 1) - tl.assume(stride_hm == 32) - tl.assume(stride_hn == 1) + tl.assume(stride_grad_Hm == 32) + tl.assume(stride_grad_Hn == 1) + tl.assume(stride_Hm == 32) + tl.assume(stride_Hn == 1) tl.assume(stride_grad_a == 1) tl.assume(stride_a == 1) tl.assume(stride_grad_b == 1) @@ -401,48 +686,66 @@ def _mhc_scale_bwd_fused( mask=mask_m[:, None] & mask_n[None, :], other=0.0, ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - h = tl.load( - h_ptr + offs_m[:, None] * stride_hm + cols[None, :] * stride_hn, + H = tl.load( + H_ptr + offs_m[:, None] * stride_Hm + cols[None, :] * stride_Hn, mask=mask_m[:, None] & mask_n[None, :], other=0.0, ) # (BLOCK_SIZE_M, BLOCK_SIZE_N) # Gradiient of H before H_pre and H_post go through sigmoid grad_out_out = grad_out * out - grad_h_pre = grad_out_out * (1 - out) - grad_h_post = grad_out_out * 0.5 * (2 - out) - grad_h = grad_out - grad_h = tl.where(cols[None, :] < n, grad_h_pre, grad_h) - grad_h = tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_h_post, grad_h) - - grad_a = tl.sum(h * grad_h / rms[:, None], axis=0).to(a.dtype) - # Write grad_a[0:4].sum to grad_a_ptr[0], grad_a[4:8].sum to grad_a_ptr[1], and grad_a[8:24].sum to grad_a_ptr[2] - tl.atomic_add(grad_a_ptr, tl.where(cols[None, :] < n, grad_a, 0.0).sum(), sem="relaxed") - tl.atomic_add( - grad_a_ptr + stride_grad_a, - tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_a, 0.0).sum(), - sem="relaxed", - ) - tl.atomic_add( - grad_a_ptr + 2 * stride_grad_a, - tl.where((cols[None, :] >= 2 * n) & (cols[None, :] < 2 * n + n * n), grad_a, 0.0).sum(), - sem="relaxed", - ) - - grad_b = tl.sum(grad_h, axis=0).to(a.dtype) - tl.atomic_add(grad_b_ptr + cols * stride_grad_b, grad_b, mask=cols < N, sem="relaxed") - - grad_rms = (tl.sum((-grad_h * h * a[None, :]), axis=1) / (rms * rms)).to(rms.dtype) + grad_H_pre = grad_out_out * (1 - out) + grad_H_post = grad_out_out * 0.5 * (2 - out) + grad_H = grad_out + grad_H = tl.where(cols[None, :] < n, grad_H_pre, grad_H) + grad_H = tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_H_post, grad_H) + grad_H = grad_H.to(tl.float32) + H = H.to(tl.float32) + + grad_a = tl.sum(H * grad_H / rms[:, None], axis=0) + grad_b = tl.sum(grad_H, axis=0) + + grad_rms = (tl.sum((-grad_H * H * a[None, :]), axis=1) / (rms * rms)).to(rms.dtype) grad_ms = grad_rms / (2 * rms) tl.store(grad_ms_ptr + ms_offsets * stride_grad_ms, grad_ms, mask=ms_mask) - grad_h = a[None, :] * grad_h / rms[:, None] + grad_H = a[None, :] * grad_H / rms[:, None] tl.store( - grad_h_ptr + offs_m[:, None] * stride_grad_hm + cols[None, :] * stride_grad_hn, - grad_h, + grad_H_ptr + offs_m[:, None] * stride_grad_Hm + cols[None, :] * stride_grad_Hn, + grad_H, mask=mask_m[:, None] & mask_n[None, :], ) + if DETERMINISTIC: + ws_grad_a_ptrs = ws_grad_a_ptr + pid * 4 + # Write grad_a[0:4].sum to grad_a_ptr[0], grad_a[4:8].sum to grad_a_ptr[1], and grad_a[8:24].sum to grad_a_ptr[2] + tl.store(ws_grad_a_ptrs, tl.where(cols[None, :] < n, grad_a, 0.0).sum()) + tl.store( + ws_grad_a_ptrs + 1, + tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_a, 0.0).sum(), + ) + tl.store( + ws_grad_a_ptrs + 2, + tl.where((cols[None, :] >= 2 * n) & (cols[None, :] < 2 * n + n * n), grad_a, 0.0).sum(), + ) + ws_grad_b_ptrs = ws_grad_b_ptr + pid * 32 + cols + tl.store(ws_grad_b_ptrs, grad_b, mask=cols < N) + else: + # Write grad_a[0:4].sum to grad_a_ptr[0], grad_a[4:8].sum to grad_a_ptr[1], and grad_a[8:24].sum to grad_a_ptr[2] + tl.atomic_add(grad_a_ptr, tl.where(cols[None, :] < n, grad_a, 0.0).sum(), sem="relaxed") + tl.atomic_add( + grad_a_ptr + stride_grad_a, + tl.where((cols[None, :] >= n) & (cols[None, :] < 2 * n), grad_a, 0.0).sum(), + sem="relaxed", + ) + tl.atomic_add( + grad_a_ptr + 2 * stride_grad_a, + tl.where((cols[None, :] >= 2 * n) & (cols[None, :] < 2 * n + n * n), grad_a, 0.0).sum(), + sem="relaxed", + ) + + tl.atomic_add(grad_b_ptr + cols * stride_grad_b, grad_b, mask=cols < N, sem="relaxed") + def sinkhorn_config(): block = [256, 1024] @@ -461,73 +764,7 @@ def sinkhorn_config(): key=["M"], ) @triton.jit -def _mhc_sinkhorn_fwd_fused_recompute( - x_ptr, # (M, n*n) - output_ptr, # (M, n*n) - stride_xm, - stride_xn, - stride_out_m, - stride_out_n, - M, - n: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - iters, -): - pid = tl.program_id(0) - - tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n") - tl.assume(M > 0 and iters > 0) - tl.assume(n == 4) - - BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n) - - offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE) - offs_nn = tl.arange(0, n * n) - mask_batch = offs_batch < M - - x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn - x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) - x = tl.reshape(x, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) - - log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - - f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - - for _ in range(iters): - # Update f: logsumexp over the column dimension (1) - f = x + g[:, None, :] # Broadcast g to (BATCH_SIZE, n, n) - f_max = tl.max(f, axis=2) - f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns - f = log_mu - f - f_max - - # Update g: logsumexp over the row dimension (2) - g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) - g_max = tl.max(g, axis=1) - g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows - g = log_nu - g - g_max - - log_P = f[:, :, None] + x + g[:, None, :] - log_P = tl.reshape( - log_P, - ( - BATCH_SIZE, - n * n, - ), - ) - P = tl.exp(log_P) - - output_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n - tl.store(output_ptrs, P, mask=mask_batch[:, None]) - - -@triton.autotune( - configs=sinkhorn_config(), - key=["M"], -) -@triton.jit -def _mhc_sinkhorn_bwd_fused_recompute( +def _mhc_sinkhorn_bwd_fused( grad_out_ptr, output_ptr, grad_x_ptr, @@ -546,6 +783,7 @@ def _mhc_sinkhorn_bwd_fused_recompute( n: tl.constexpr, BLOCK_SIZE: tl.constexpr, iters, + RECOMPUTE: tl.constexpr, ): pid = tl.program_id(0) @@ -578,40 +816,41 @@ def _mhc_sinkhorn_bwd_fused_recompute( sbn = M * n - # Recompute the full history of f and g - log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + if RECOMPUTE: + # Recompute the full history of f and g + log_mu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + log_nu = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) + g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] - g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] - tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) - tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) - for iter_idx in range(iters): - # Update f: logsumexp over the column dimension (1) - f = x + g[:, None, :] # Broadcast g to (BATCH_SIZE, n, n) - f_max = tl.max(f, axis=2) - f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns - f = log_mu - f - f_max + for iter_idx in range(iters): + # Update f: logsumexp over the column dimension (1) + f = x + g[:, None, :] # Broadcast g to (BATCH_SIZE, n, n) + f_max = tl.max(f, axis=2) + f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns + f = log_mu - f - f_max - f_hist_ptrs = ( - hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] - ) - tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + f_hist_ptrs = ( + hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) - # Update g: logsumexp over the row dimension (2) - g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) - g_max = tl.max(g, axis=1) - g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows - g = log_nu - g - g_max + # Update g: logsumexp over the row dimension (2) + g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) + g_max = tl.max(g, axis=1) + g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows + g = log_nu - g - g_max - g_hist_ptrs = ( - hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] - ) - tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + g_hist_ptrs = ( + hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) # Backward pass grad_log_P = grad_out * P # (BATCH_SIZE, n, n) @@ -670,8 +909,8 @@ def _mhc_sinkhorn_bwd_fused_recompute( def _mhc_sinkhorn_fwd_fused( x_ptr, # (M, n*n) output_ptr, # (M, n*n) - hist_f_ptr, # (iters+1, M, n) - hist_g_ptr, # (iters+1, M, n) + hist_f_ptr, # (iters+1, M, n), or None if RECOMPUTE is True + hist_g_ptr, # (iters+1, M, n), or None if RECOMPUTE is True stride_xm, stride_xn, stride_out_m, @@ -680,6 +919,7 @@ def _mhc_sinkhorn_fwd_fused( n: tl.constexpr, BLOCK_SIZE: tl.constexpr, iters, + RECOMPUTE: tl.constexpr, ): pid = tl.program_id(0) @@ -704,13 +944,13 @@ def _mhc_sinkhorn_fwd_fused( f = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) g = tl.zeros((BATCH_SIZE, n), dtype=x.dtype) # (BATCH_SIZE, n) - sbn = M * n - - # Store the initial f and g to history - f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] - g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] - tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) - tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + if not RECOMPUTE: + sbn = M * n + # Store the initial f and g to history + f_hist_ptrs = hist_f_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + g_hist_ptrs = hist_g_ptr + offs_batch[:, None] * n + offs_n_hist[None, :] + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) for iter_idx in range(iters): # Update f: logsumexp over the column dimension (1) @@ -719,10 +959,11 @@ def _mhc_sinkhorn_fwd_fused( f = tl.log(tl.sum(tl.exp(f - f_max[:, :, None]), axis=2)) # logsumexp over columns f = log_mu - f - f_max - f_hist_ptrs = ( - hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] - ) - tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) + if not RECOMPUTE: + f_hist_ptrs = ( + hist_f_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(f_hist_ptrs, f, mask=mask_batch[:, None]) # Update g: logsumexp over the row dimension (2) g = x + f[:, :, None] # Broadcast f to (BATCH_SIZE, n, n) @@ -730,10 +971,11 @@ def _mhc_sinkhorn_fwd_fused( g = tl.log(tl.sum(tl.exp(g - g_max[:, None, :]), axis=1)) # logsumexp over rows g = log_nu - g - g_max - g_hist_ptrs = ( - hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] - ) - tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) + if not RECOMPUTE: + g_hist_ptrs = ( + hist_g_ptr + (iter_idx + 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] + ) + tl.store(g_hist_ptrs, g, mask=mask_batch[:, None]) log_P = f[:, :, None] + x + g[:, None, :] log_P = tl.reshape( @@ -749,130 +991,37 @@ def _mhc_sinkhorn_fwd_fused( tl.store(output_ptrs, P, mask=mask_batch[:, None]) +def aggregate_config_fwd(): + block_m = [1, 2, 4] + block_c = [128, 256] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for m, c, w, s in itertools.product(block_m, block_c, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + +def aggregate_prune_fwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs + ) + ) + return pruned_configs + + @triton.autotune( - configs=sinkhorn_config(), - key=["M"], -) -@triton.jit -def _mhc_sinkhorn_bwd_fused( - grad_out_ptr, # (M, n*n) - output_ptr, # (M, n*n) - grad_x_ptr, # (M, n*n) - x_ptr, # (M, n*n) - hist_f_ptr, # (iters+1, M, n) - hist_g_ptr, # (iters+1, M, n) - stride_grad_out_m, - stride_grad_out_n, - stride_out_m, - stride_out_n, - stride_grad_xm, - stride_grad_xn, - stride_xm, - stride_xn, - M, - n: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - iters, -): - pid = tl.program_id(0) - - tl.static_assert(BLOCK_SIZE % (n * n) == 0, "BLOCK_SIZE must be divisible by n*n") - tl.assume(M > 0 and iters > 0) - tl.assume(n == 4) - - BATCH_SIZE: tl.constexpr = BLOCK_SIZE // (n * n) # Assume there's no remainder for simplicity - - offs_batch = pid * BATCH_SIZE + tl.arange(0, BATCH_SIZE) - offs_nn = tl.arange(0, n * n) - offs_n_hist = tl.arange(0, n) - mask_batch = offs_batch < M - - x_ptrs = x_ptr + offs_batch[:, None] * stride_xm + offs_nn[None, :] * stride_xn - x = tl.load(x_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) - x = tl.reshape(x, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) - - P_ptrs = output_ptr + offs_batch[:, None] * stride_out_m + offs_nn[None, :] * stride_out_n - P = tl.load(P_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) - P = tl.reshape(P, (BATCH_SIZE, n, n)) - - grad_out_ptrs = ( - grad_out_ptr - + offs_batch[:, None] * stride_grad_out_m - + offs_nn[None, :] * stride_grad_out_n - ) - grad_out = tl.load(grad_out_ptrs, mask=mask_batch[:, None], other=0.0) # (BATCH_SIZE, n*n) - grad_out = tl.reshape(grad_out, (BATCH_SIZE, n, n)) # (BATCH_SIZE, n, n) - - sbn = M * n - - # Backward pass - grad_log_P = grad_out * P # (BATCH_SIZE, n, n) - zeros = tl.zeros_like(grad_log_P) - grad_g = tl.sum(grad_log_P, axis=1) # (BATCH_SIZE, n) - grad_x = grad_log_P - - g_hist_ptrs = hist_g_ptr + iters * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] - g = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0) - g = tl.reshape(g, (BATCH_SIZE, n)) - - for iter_idx in range(iters, 0, -1): - f_hist_ptrs = hist_f_ptr + iter_idx * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] - f = tl.load(f_hist_ptrs, mask=mask_batch[:, None], other=0.0) - f = tl.reshape(f, (BATCH_SIZE, n)) - - g_hist_ptrs = ( - hist_g_ptr + (iter_idx - 1) * sbn + offs_batch[:, None] * n + offs_n_hist[None, :] - ) - g_next = tl.load(g_hist_ptrs, mask=mask_batch[:, None], other=0.0) - g_next = tl.reshape(g_next, (BATCH_SIZE, n)) - - term_g = -grad_g[:, None, :] * tl.exp(f[:, :, None] + x + g[:, None, :]) - grad_f = tl.sum(term_g + grad_log_P, axis=2) # (BATCH_SIZE, n) - # Only the last iteration's f will contribute to gradients with both grad_g1 and grad_log_P - grad_log_P = zeros # Zero out grad_log_P for next iterations - - g = g_next - - term_f = -grad_f[:, :, None] * tl.exp(f[:, :, None] + x + g[:, None, :]) - grad_g = tl.sum(term_f, axis=1) # (BATCH_SIZE, n) - - grad_x += term_f + term_g - - grad_x_ptrs = ( - grad_x_ptr + offs_batch[:, None] * stride_grad_xm + offs_nn[None, :] * stride_grad_xn - ) - tl.store( - grad_x_ptrs, - tl.reshape( - grad_x, - ( - BATCH_SIZE, - n * n, - ), - ), - mask=mask_batch[:, None], - ) - - -def aggregate_config(): - block_m = [1, 2, 4] - block_c = [64, 128, 256] - warps = [1, 2, 4] - stages = [1, 2, 3, 4] - - configs = [] - for m, c, w, s in itertools.product(block_m, block_c, warps, stages): - configs.append( - triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) - ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] - return configs - - -@triton.autotune( - configs=aggregate_config(), - key=["M", "C"], + configs=aggregate_config_fwd(), + key=["M", "C"], + prune_configs_by={"early_config_prune": aggregate_prune_fwd}, ) @triton.jit def _mhc_aggregate_fwd( @@ -949,7 +1098,79 @@ def _mhc_aggregate_fwd( tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_c[None, :]) -@triton.autotune(configs=aggregate_config(), key=["M", "C"], reset_to_zero=["grad_H_pre_ptr"]) +def aggregate_config_bwd(): + block_m = [1, 2, 4] + block_c = [64, 128, 256] + step_c = [32, 64] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for bm, bc, sc, w, s in itertools.product(block_m, block_c, step_c, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_C": bc, "STEP_SIZE_C": sc, "USE_SPLIT_C": True}, + num_warps=w, + num_stages=s, + ) + ) + return configs + + +def aggregate_prune_bwd(configs, named_args, **kwargs): + DETERMINISTIC = named_args.get("DETERMINISTIC", kwargs.get("DETERMINISTIC", None)) + M = named_args.get("M", kwargs.get("M", None)) + C = named_args.get("C", kwargs.get("C", None)) + + block_m = [4] + block_c = align_to(C, 64) + + # Use Split-K only if determinism is not enforced and M is not large enough to effectively parallelize + if not DETERMINISTIC and triton.cdiv(M, block_m[0]) < get_device_sms() * 4: + pruned_configs = configs + else: + step_c = [64] + warps = [1] + stages = [2, 3, 4] + + pruned_configs = [] + for bm, sc, w, s in itertools.product(block_m, step_c, warps, stages): + pruned_configs.append( + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_C": block_c, + "STEP_SIZE_C": sc, + "USE_SPLIT_C": False, + }, + num_warps=w, + num_stages=s, + ) + ) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, + pruned_configs, + ) + ) + + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs + + +@triton.autotune( + configs=aggregate_config_bwd(), + key=["M", "C", "DETERMINISTIC"], + reset_to_zero=["grad_H_pre_ptr"], + # When FUSE_GRAD_X_ACC=True the kernel does a read-modify-write on grad_x_ptr; without + # restore_value the autotune timing trials accumulate onto the buffer and corrupt it. + restore_value=["grad_x_ptr"], + prune_configs_by={"early_config_prune": aggregate_prune_bwd}, +) @triton.jit def _mhc_aggregate_bwd( grad_output_ptr, # (M, C) @@ -969,7 +1190,11 @@ def _mhc_aggregate_bwd( # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + STEP_SIZE_C: tl.constexpr, precision: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, + DETERMINISTIC: tl.constexpr, # pylint: disable=unused-argument # If user wants to enforce deterministic, which is used to prune configs + USE_SPLIT_C: tl.constexpr, ): """ Forward: @@ -992,38 +1217,14 @@ def _mhc_aggregate_bwd( tl.assume(stride_grad_output_m > 0 and stride_grad_output_c == 1) tl.assume(BLOCK_SIZE_C % 32 == 0) + tl.assume(STEP_SIZE_C % 32 == 0) + tl.assume(BLOCK_SIZE_C % STEP_SIZE_C == 0) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - grad_output_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_c[None, :] * stride_grad_output_c - ) - grad_output = tl.load( - grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C) - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - - grad_H_pre = tl.dot( - tl.reshape(grad_output, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) - grad_H_pre = tl.reshape(grad_H_pre, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) - offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre - tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed") + offs_c_start = pid_c * BLOCK_SIZE_C + offs_cn_start = pid_c * BLOCK_SIZE_C * n H_pre_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_pre = tl.load( @@ -1031,19 +1232,62 @@ def _mhc_aggregate_bwd( ) # (BLOCK_SIZE_M * n) H_pre = tl.reshape(H_pre, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - grad_x = grad_output[:, :, None] * H_pre[:, None, :] # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) + grad_H_pre_acc = tl.zeros((BLOCK_SIZE_M, 1, n), dtype=tl.float32) + for i in tl.range(0, BLOCK_SIZE_C, STEP_SIZE_C, loop_unroll_factor=2): + offs_c = offs_c_start + i + tl.arange(0, STEP_SIZE_C) + offs_cn = offs_cn_start + i * n + tl.arange(0, STEP_SIZE_C * n) + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + grad_output_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_c[None, :] * stride_grad_output_c + ) + grad_output = tl.load( + grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C) - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store( - grad_x_ptrs, - grad_x, - mask=mask_m[:, None] & mask_cn[None, :], - ) + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C * n) + + grad_H_pre_acc = tl.dot( + tl.reshape(grad_output, (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(x, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_pre_acc, + input_precision=precision, + out_dtype=tl.float32, + ) + + # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, STEP_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, STEP_SIZE_C, n) + grad_x = grad_output[:, :, None] * H_pre[:, None, :] # (BLOCK_SIZE_M, STEP_SIZE_C, n) + grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, STEP_SIZE_C * n)) + + grad_x_ptrs = ( + grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + ) + + if FUSE_GRAD_X_ACC: # If fused gradient accumulation is enabled, the buffer is always fp32 + grad_x_acc = tl.load(grad_x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0) + grad_x = grad_x.to(tl.float32) + grad_x_acc + tl.store( + grad_x_ptrs, + grad_x, + mask=mask_m[:, None] & mask_cn[None, :], + ) + grad_H_pre = tl.reshape(grad_H_pre_acc, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre + if USE_SPLIT_C: + tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed") + else: + tl.store(grad_H_pre_ptrs, grad_H_pre.to(H_pre.dtype), mask=offs_grad_H_pre < M * n) -def expand_combine_config(): + +def expand_combine_config_fwd(): block_m = [1, 2, 4] block_c = [128, 256] warps = [1, 2] @@ -1059,13 +1303,26 @@ def expand_combine_config(): return configs +def expand_combine_prune_fwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs + ) + ) + return pruned_configs + + @triton.autotune( - configs=expand_combine_config(), + configs=expand_combine_config_fwd(), key=["M", "C"], + prune_configs_by={"early_config_prune": expand_combine_prune_fwd}, ) @triton.jit def _mhc_expand_combine_fwd( f_ptr, # (M, C) + bias_ptr, # (C,), or None if HAS_BIAS is False H_post_ptr, # (M, n) x_ptr, # (M, C, n) H_res_ptr, # (M, n, n) @@ -1075,6 +1332,7 @@ def _mhc_expand_combine_fwd( n: tl.constexpr, stride_fm, stride_fc, + stride_bias, # Not used if HAS_BIAS is False stride_xm, stride_xCn, stride_output_m, @@ -1082,9 +1340,10 @@ def _mhc_expand_combine_fwd( # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + HAS_BIAS: tl.constexpr, ): """ - output = f @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) """ pid_m = tl.program_id(1) @@ -1095,6 +1354,7 @@ def _mhc_expand_combine_fwd( tl.assume(C > 0) tl.assume(n == 4) tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_bias == 1) tl.assume(stride_xm > 0 and stride_xCn == 1) tl.assume(stride_output_m > 0 and stride_output_Cn == 1) @@ -1109,6 +1369,8 @@ def _mhc_expand_combine_fwd( f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + if HAS_BIAS: + bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_post = tl.load( @@ -1116,10 +1378,12 @@ def _mhc_expand_combine_fwd( ) H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # Residual connection path: res_out = f @ H_post: + # Residual connection path: res_out = f @ H_post + bias @ H_post: # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) # Due to broadcasting, it's equivalent to a multiplicaiton out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) + if HAS_BIAS: + out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc) out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) @@ -1167,332 +1431,89 @@ def _mhc_expand_combine_fwd( tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) -@triton.autotune( - configs=expand_combine_config(), - key=["M", "C"], - reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr"], -) -@triton.jit -def _mhc_expand_combine_bwd( - grad_output_ptr, # (M, C, n) - f_ptr, # (M, C) - H_post_ptr, # (M, n) - x_ptr, # (M, C, n) - H_res_ptr, # (M, n, n) - grad_H_post_ptr, # (M, n) - grad_f_ptr, # (M, C) - grad_H_res_ptr, # (M, n, n) - grad_x_ptr, # (M, C, n) - M, - C, - n: tl.constexpr, - stride_grad_output_m, - stride_grad_output_Cn, - stride_fm, - stride_fc, - stride_xm, - stride_xCn, - stride_grad_fm, - stride_grad_fc, - stride_grad_xm, - stride_grad_xCn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_C: tl.constexpr, - precision: tl.constexpr, -): - """ - Each block - It reads - - (BLOCK_SIZE_M, BLOCK_SIZE_C) of f, which is the output of the attention / FFN module - - (BLOCK_SIZE_M, n) of H_post, which is applied for the transformation of the attention / FFN output - - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of x, which is the skip connection's input - - (BLOCK_SIZE_M, n*n) of H_res, which is applied for the transformation of the skip connection - and writes - - (BLOCK_SIZE_M, n) of grad_H_post - - (BLOCK_SIZE_M, BLOCK_SIZE_C) of grad_f - - (BLOCK_SIZE_M, n, n) of grad_H_res - - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of grad_x - - Forward: - out = f @ H_post + x @ H_res - Backward: - GEMM: - grad_H_post = f.T @ grad_output: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - Not GEMM: - grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) - grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - """ - - pid_m = tl.program_id(1) - pid_c = tl.program_id(0) - - tl.static_assert(n == 4) - tl.assume(M > 0) - tl.assume(C > 0) - tl.assume(n == 4) - tl.assume(stride_fm > 0 and stride_fc == 1) - tl.assume(stride_xm > 0 and stride_xCn == 1) - tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1) - tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1) - tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) - - tl.assume(BLOCK_SIZE_C % 32 == 0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) - mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) - - H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - - H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - H_res = tl.load( - H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 - ) # (BLOCK_SIZE_M, n, n) - H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) - - grad_out_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_cn[None, :] * stride_grad_output_Cn - ) - grad_out = tl.load( - grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_out = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) - offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post - tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - grad_H_res = tl.dot( - tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 - ) # (BLOCK_SIZE_M, n, n) - grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) - offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res - tl.atomic_add( - grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" - ) - - grad_out_reshape = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - grad_out01, grad_out23 = tl.split( - grad_out_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - grad_out0, grad_out1 = tl.split( - grad_out01 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_out2, grad_out3 = tl.split( - grad_out23 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) - # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: - # grad_f = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) - # + grad_out[:, :, 1] @ H_post.T[:, 1, :] - # + grad_out[:, :, 2] @ H_post.T[:, 2, :] - # + grad_out[:, :, 3] @ H_post.T[:, 3, :] - # where H_post.T[:, i, :] = H_post[:, :, i] - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) - H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) - H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - - grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) - # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) - grad_f = grad_f_acc.to(f.dtype) - - grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc - tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) - - # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul - # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] - # + grad_out[:, :, 1] @ H_res.T[:, 1, :] - # + grad_out[:, :, 2] @ H_res.T[:, 2, :] - # + grad_out[:, :, 3] @ H_res.T[:, 3, :] - # where H_res.T[:, i, :] = H_res[:, :, i] - # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] - - H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) - H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - - grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) - - grad_x = grad_x_acc.to(x.dtype) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) - - -@triton.autotune( - configs=expand_combine_config(), - key=["M", "C"], -) -@triton.jit -def _mhc_expand_combine_with_bias_fwd( - f_ptr, # (M, C) - bias_ptr, # (C,) - H_post_ptr, # (M, n) - x_ptr, # (M, C, n) - H_res_ptr, # (M, n, n) - output_ptr, # # (M, C, n) - M, - C, - n: tl.constexpr, - stride_fm, - stride_fc, - stride_bias, - stride_xm, - stride_xCn, - stride_output_m, - stride_output_Cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_C: tl.constexpr, -): - """ - output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - """ - pid_m = tl.program_id(1) - pid_c = tl.program_id(0) - - tl.static_assert(n == 4) - tl.assume(M > 0) - tl.assume(C > 0) - tl.assume(n == 4) - tl.assume(stride_fm > 0 and stride_fc == 1) - tl.assume(stride_bias == 1) - tl.assume(stride_xm > 0 and stride_xCn == 1) - tl.assume(stride_output_m > 0 and stride_output_Cn == 1) - - tl.assume(BLOCK_SIZE_C % 32 == 0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) - mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n +def expand_combine_config_bwd(): + block_m = [1, 2, 4] + block_c = [128, 256] + step_c = [32, 64] + warps = [1, 2] + stages = [1, 2, 3, 4] - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) - bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + configs = [] + for m, c, sc, w, s in itertools.product(block_m, block_c, step_c, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c, "STEP_SIZE_C": sc, "USE_SPLIT_C": True}, + num_warps=w, + num_stages=s, + ) + ) + return configs - offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - H_post = tl.load( - H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca" - ) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # Residual connection path: res_out = f @ H_post + bias @ H_post: - # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # Due to broadcasting, it's equivalent to a multiplicaiton - out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc) - out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) +def expand_combine_prune_bwd(configs, named_args, **kwargs): + DETERMINISTIC = named_args.get("DETERMINISTIC", kwargs.get("DETERMINISTIC", None)) + M = named_args.get("M", kwargs.get("M", None)) + C = named_args.get("C", kwargs.get("C", None)) + + block_m = [4] + block_c = align_to(C, 32) + + # Use Split-K only if determinism is not enforced and M is not large enough to effectively parallelize + # sms * 8 is a empirical threshold I found via experiments on B200 for non-split-K starts to be better + if not DETERMINISTIC and triton.cdiv(M, block_m[0]) < get_device_sms() * 8: + pruned_configs = configs + else: + step_c = [32, 64, 128] + warps = [1, 2] + stages = [1, 2, 3, 4] + + pruned_configs = [] + for bm, sc, w, s in itertools.product(block_m, step_c, warps, stages): + pruned_configs.append( + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_C": block_c, + "STEP_SIZE_C": sc, + "USE_SPLIT_C": False, + }, + num_warps=w, + num_stages=s, + ) + ) - H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - H_res = tl.load( - H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca" + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, + pruned_configs, + ) ) - H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # Manifold connection path: manifold_out = H_res @ x: - # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4: - # x @ H_res = x[:, :, 0] @ H_res[:, 0, :] - # + x[:, :, 1] @ H_res[:, 1, :] - # + x[:, :, 2] @ H_res[:, 2, :] - # + x[:, :, 3] @ H_res[:, 3, :] - x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)) - x01, x23 = tl.split( - x_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - x0, x1 = tl.split(x01) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - x2, x3 = tl.split(x23) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2)) - H_res01, H_res23 = tl.split(H_resT) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - - out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc) - out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc) - out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc) - out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc) - - out = out_acc.to(x.dtype) - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - output_ptrs = ( - output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn - ) - tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs @triton.autotune( - configs=expand_combine_config(), - key=["M", "C"], + configs=expand_combine_config_bwd(), + key=["M", "C", "DETERMINISTIC"], reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr", "grad_bias_ptr"], + prune_configs_by={"early_config_prune": expand_combine_prune_bwd}, ) @triton.jit -def _mhc_expand_combine_with_bias_bwd( +def _mhc_expand_combine_bwd( grad_output_ptr, # (M, C, n) f_ptr, # (M, C) - bias_ptr, # (C,) + bias_ptr, # (C,), or None if HAS_BIAS is False H_post_ptr, # (M, n) x_ptr, # (M, C, n) H_res_ptr, # (M, n, n) grad_H_post_ptr, # (M, n) grad_f_ptr, # (M, C) - grad_bias_ptr, # (C,) + grad_bias_ptr, # (C,), or None if HAS_BIAS is False + grad_bias_ws_ptr, # (M // BLOCK_SIZE_M, C), or None if HAS_BIAS is False or DETERMINISTIC is False grad_H_res_ptr, # (M, n, n) grad_x_ptr, # (M, C, n) M, @@ -1502,18 +1523,24 @@ def _mhc_expand_combine_with_bias_bwd( stride_grad_output_Cn, stride_fm, stride_fc, - stride_bias, + stride_bias, # Not used if HAS_BIAS is False stride_xm, stride_xCn, stride_grad_fm, stride_grad_fc, - stride_grad_bias, + stride_grad_bias, # Not used if HAS_BIAS is False + stride_grad_bias_ws_m, # Not used if HAS_BIAS is False or DETERMINISTIC is False + stride_grad_bias_ws_c, # Not used if HAS_BIAS is False or DETERMINISTIC is False stride_grad_xm, stride_grad_xCn, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + STEP_SIZE_C: tl.constexpr, precision: tl.constexpr, + HAS_BIAS: tl.constexpr, + DETERMINISTIC: tl.constexpr, + USE_SPLIT_C: tl.constexpr, ): """ Each block @@ -1552,142 +1579,184 @@ def _mhc_expand_combine_with_bias_bwd( tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1) tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1) tl.assume(stride_grad_bias == 1) + tl.assume(stride_grad_bias_ws_m == C) + tl.assume(stride_grad_bias_ws_c == 1) tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) tl.assume(BLOCK_SIZE_C % 32 == 0) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + offs_c_start = pid_c * BLOCK_SIZE_C + offs_cn_start = pid_c * BLOCK_SIZE_C * n - bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + grad_H_post_acc = tl.zeros((BLOCK_SIZE_M, 1, n), dtype=tl.float32) + grad_H_res_acc = tl.zeros((BLOCK_SIZE_M, n, n), dtype=tl.float32) H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + H_post_reshape = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) + H_post01, H_post23 = tl.split(H_post_reshape) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) + H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) H_res = tl.load( H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 ) # (BLOCK_SIZE_M, n, n) H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) + H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - grad_out_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_cn[None, :] * stride_grad_output_Cn - ) - grad_out = tl.load( - grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_out = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + for i in tl.range(0, BLOCK_SIZE_C, STEP_SIZE_C, loop_unroll_factor=2): + offs_c = offs_c_start + i + tl.arange(0, STEP_SIZE_C) + offs_cn = offs_cn_start + i * n + tl.arange(0, STEP_SIZE_C * n) + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + + if HAS_BIAS: + bias = tl.load( + bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0 + ) # (STEP_SIZE_C,) + + grad_out_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_cn[None, :] * stride_grad_output_Cn + ) + grad_out = tl.load( + grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C * n) + grad_out = tl.reshape( + grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n) + ) # (BLOCK_SIZE_M, STEP_SIZE_C, n) + + # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, STEP_SIZE_C) @ (BLOCK_SIZE_M, STEP_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_post_acc = tl.dot( + tl.reshape(f, (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_post_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + if HAS_BIAS: + grad_H_post_acc = tl.dot( + tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_post_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C*n) + x = tl.reshape(x, (BLOCK_SIZE_M, STEP_SIZE_C, n)) # (BLOCK_SIZE_M, STEP_SIZE_C, n) + + # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, STEP_SIZE_C) @ (BLOCK_SIZE_M, STEP_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + grad_H_res_acc = tl.dot( + tl.trans(x, (0, 2, 1)), + grad_out, + acc=grad_H_res_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, n, n) + + grad_out_reshape = tl.reshape( + grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, 2, 2) + ) # (BLOCK_SIZE_M, STEP_SIZE_C, 2, 2) + grad_out01, grad_out23 = tl.split( + grad_out_reshape + ) # (BLOCK_SIZE_M, STEP_SIZE_C, 2), (BLOCK_SIZE_M, STEP_SIZE_C, 2) + grad_out0, grad_out1 = tl.split( + grad_out01 + ) # (BLOCK_SIZE_M, STEP_SIZE_C), (BLOCK_SIZE_M, STEP_SIZE_C) + grad_out2, grad_out3 = tl.split( + grad_out23 + ) # (BLOCK_SIZE_M, STEP_SIZE_C), (BLOCK_SIZE_M, STEP_SIZE_C) + + # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, STEP_SIZE_C) = (BLOCK_SIZE_M, 1, STEP_SIZE_C) + # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: + # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, STEP_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) + # + grad_out[:, :, 1] @ H_post.T[:, 1, :] + # + grad_out[:, :, 2] @ H_post.T[:, 2, :] + # + grad_out[:, :, 3] @ H_post.T[:, 3, :] + # where H_post.T[:, i, :] = H_post[:, :, i] + + grad_f_acc = tl.zeros((BLOCK_SIZE_M, STEP_SIZE_C), dtype=tl.float32) + # (BLOCK_SIZE_M, STEP_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, STEP_SIZE_C) + grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) + grad_f = grad_f_acc.to(f.dtype) + + grad_f_ptrs = ( + grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc + ) + tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) + + if HAS_BIAS: + grad_bias = tl.sum(grad_f_acc, axis=0) # (STEP_SIZE_C,) + if DETERMINISTIC: + grad_bias_ws_ptrs = ( + grad_bias_ws_ptr + + pid_m * stride_grad_bias_ws_m + + offs_c * stride_grad_bias_ws_c + ) + tl.store(grad_bias_ws_ptrs, grad_bias, mask=mask_c) + else: + grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias + tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") + + # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, STEP_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, STEP_SIZE_C) + # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul + # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] + # + grad_out[:, :, 1] @ H_res.T[:, 1, :] + # + grad_out[:, :, 2] @ H_res.T[:, 2, :] + # + grad_out[:, :, 3] @ H_res.T[:, 3, :] + # where H_res.T[:, i, :] = H_res[:, :, i] + # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] + + grad_x_acc = tl.zeros((BLOCK_SIZE_M, STEP_SIZE_C, n), dtype=tl.float32) + grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) + + grad_x = grad_x_acc.to(x.dtype) + grad_x = tl.reshape( + grad_x, (BLOCK_SIZE_M, STEP_SIZE_C * n) + ) # (BLOCK_SIZE_M, STEP_SIZE_C*n) + + grad_x_ptrs = ( + grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + ) + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) - # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - acc=grad_H_post, - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + grad_H_post = tl.reshape(grad_H_post_acc, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post - tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - grad_H_res = tl.dot( - tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 - ) # (BLOCK_SIZE_M, n, n) - grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) + grad_H_res = tl.reshape(grad_H_res_acc, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res - tl.atomic_add( - grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" - ) - - grad_out_reshape = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - grad_out01, grad_out23 = tl.split( - grad_out_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - grad_out0, grad_out1 = tl.split( - grad_out01 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_out2, grad_out3 = tl.split( - grad_out23 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) - # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: - # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) - # + grad_out[:, :, 1] @ H_post.T[:, 1, :] - # + grad_out[:, :, 2] @ H_post.T[:, 2, :] - # + grad_out[:, :, 3] @ H_post.T[:, 3, :] - # where H_post.T[:, i, :] = H_post[:, :, i] - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) - H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) - H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - - grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) - # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) - grad_f = grad_f_acc.to(f.dtype) - - grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc - tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) - - grad_bias = tl.sum(grad_f_acc, axis=0) # (BLOCK_SIZE_C,) - grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias - tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") - - # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul - # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] - # + grad_out[:, :, 1] @ H_res.T[:, 1, :] - # + grad_out[:, :, 2] @ H_res.T[:, 2, :] - # + grad_out[:, :, 3] @ H_res.T[:, 3, :] - # where H_res.T[:, i, :] = H_res[:, :, i] - # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] - - H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) - H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) - - grad_x = grad_x_acc.to(x.dtype) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) + if USE_SPLIT_C: + tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") + tl.atomic_add( + grad_H_res_ptrs, + grad_H_res.to(tl.float32), + mask=offs_grad_H_res < M * n * n, + sem="relaxed", + ) + else: + tl.store(grad_H_post_ptrs, grad_H_post.to(H_post.dtype), mask=offs_grad_H_post < M * n) + tl.store(grad_H_res_ptrs, grad_H_res.to(H_res.dtype), mask=offs_grad_H_res < M * n * n) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 987216e327..c8d61b5cd1 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -9,38 +9,119 @@ import triton from transformer_engine.common.triton.mhc import ( + _mhc_projection_bwd_fused_dphi, + _mhc_projection_bwd_fused_dx, _mhc_scale_fwd_fused, _mhc_scale_bwd_fused, - _mhc_expand_combine_with_bias_fwd, - _mhc_expand_combine_with_bias_bwd, _mhc_expand_combine_fwd, _mhc_expand_combine_bwd, _mhc_aggregate_fwd, _mhc_aggregate_bwd, _mhc_projection_fwd_fused, - _mhc_projection_bwd_fused, _mhc_sinkhorn_fwd_fused, - _mhc_sinkhorn_fwd_fused_recompute, _mhc_sinkhorn_bwd_fused, - _mhc_sinkhorn_bwd_fused_recompute, ) from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm -def check_deterministic(operator: str): +def is_deterministic_enforced(): """ - Checks if the non-deterministic algorithm is allowed for the given operator. If not, raises an assertion error with instructions on how to allow it. - Since atomic add is used in this mHC implementation, it breaks the determinism guarantee due to non-associativity of floating point addition. + Check if user enforces deterministic algorithms. We assume non-determinism is allowed if this flag is not set """ - allow_nondeterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "1" - assert allow_nondeterministic, ( - f"[{operator}]: This operation uses atomic add which violates determinism. Set" - " NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 to allow this non-deterministic behavior." + return os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0" + + +def mhc_generate_mix_and_aggregate( + x: torch.Tensor, + phi: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, + norm_weight: torch.Tensor = None, + use_tf32: bool = True, + fuse_grad_x_acc: bool = False, +): + """ + Generate the mix matrix H_pre, H_post, H_res and apply H_pre to x to aggregate n streams + This wraps projection, scale, sinkhorn, and aggregate operations into one function. + + To use mHC in your model: + ``` + layer_input, H_post, H_res = mhc_generate_mix_and_aggregate(x, phi, alpha, beta) + layer_output = layer(layer_input) # Attn / FFN layer + x = mhc_fused_expand_combine(layer_input, bias, H_post, x, H_res) + ``` + + This API accepts both BF16 and FP32 parameters, though the DeepSeek V4 recipe is: + - x: BF16 + - phi, alpha, beta: FP32 + + Parameters + ---------- + x : torch.Tensor, + input tensor of shape (s, b, C, n), where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection, and n is the number of hyper connections, + dtype is torch.bfloat16 or torch.float32 + Note that C is equal to the original hidden dimension divided by n. + phi : torch.Tensor + projection matrix of shape (N, nC), where N=2n+n*n (=24 for n=4), and nC is the hidden dimension after expansion (n times of C), + dtype is torch.bfloat16 or torch.float32 + norm_weight : torch.Tensor or None + optional, the weight for RMSNorm, of shape (K,), which is the learnable per-element affine parameters (gamma) applied to RMSNorm + dtype is torch.bfloat16 or torch.float32 + alpha : torch.Tensor + scaling factor for H, of shape (3,), where + alpha[0] is applied to H[:, 0:n] for H_pre + alpha[1] is applied to H[:, n:2n] for H_post + alpha[2] is applied to H[:, 2n:2n+n*n] for H_res + dtype: torch.bfloat16 or torch.float32 + beta : torch.Tensor + bias term for H, of shape (1, 2*n+n*n), where + beta[0, 0:n] is applied to H[:, 0:n] for H_pre + beta[0, n:2n] is applied to H[:, n:2n] for H_post + beta[0, 2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res + dtype is torch.bfloat16 or torch.float32 + use_tf32 : bool + whether to use TF32 for matrix multiplications + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: you must enable this flag for both `mhc_generate_mix_and_aggregate` and `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. + + Returns + ------- + out : torch.Tensor + out of shape (s, b, C), which is the aggregated result after applying H_pre to x, which will be fed into attention / FFN + with the same dtype as x + H_post : torch.Tensor + H_post of shape (s, b, n), which will be used in the post-processing after attention / FFN in `mhc_fused_expand_combine` + with dtype float32 + H_res : torch.Tensor + H_res of shape (s, b, n, n), which will be used to mix the residual connection in `mhc_fused_expand_combine` + with dtype float32 + """ + s, b, C, n = x.shape + assert ( + n == 4 + ), "Only n=4 is supported in this implementation, where n is the Hyper Connection number" + nC = n * C + H, ms = mhc_fused_projection( + x.view(s * b, nC), phi, norm_weight, use_tf32=use_tf32, fuse_grad_x_acc=fuse_grad_x_acc + ) + h_pre, h_post, h_res = mhc_fused_scale(H, alpha, beta, ms, n) + H_pre = h_pre.view(s, b, n) + H_post = h_post.view(s, b, n) + H_res = h_res.view(s, b, n, n) + H_res = mhc_fused_sinkhorn(H_res, n, recompute_hist=True, iters=20) + out = mhc_fused_aggregate( + x, H_pre.view(s, b, n), n, use_tf32=use_tf32, fuse_grad_x_acc=fuse_grad_x_acc ) + return out, H_post, H_res def mhc_fused_sinkhorn( - H_res: torch.Tensor, n: int = 4, recompute_hist: bool = True, iters: int = 20 + H_res: torch.Tensor, + n: int = 4, + recompute_hist: bool = True, + iters: int = 20, ): """ Sinkhorn operation to compute the final H_res matrix (see eq. 19, section 4.3.1 of the DeepSeek mHC paper): @@ -52,6 +133,7 @@ def mhc_fused_sinkhorn( ---------- H_res : torch.Tensor input H_res matrix of shape (s, b, n, n) that needs to be normalized into a doubly stochastic matrix. + dtype is torch.bfloat16 or torch.float32 n : int number of hyper connections, where only n=4 is supported in the current implementation recompute_hist : bool @@ -63,6 +145,7 @@ def mhc_fused_sinkhorn( ------- out : torch.Tensor out of shape (s, b, n, n), which is the final H_res after Sinkhorn normalization + with the same dtype as H_res """ assert n == 4, "Only n=4 is supported in this implementation" out = mHCSinkhornOp.apply(H_res, n, recompute_hist, iters) @@ -70,7 +153,11 @@ def mhc_fused_sinkhorn( def mhc_fused_scale( - H: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor, ms: torch.Tensor, n: int + H: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, + ms: torch.Tensor, + n: int, ): """ Fused scale operation to compute the scaled H matrices (see eq. 16-18, section 4.3.1 of the DeepSeek mHC paper): @@ -96,6 +183,7 @@ def mhc_fused_scale( beta[0, 0:n] is applied to H[:, 0:n] for H_pre beta[0, n:2n] is applied to H[:, n:2n] for H_post beta[0, 2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res + Note: we assume alpha and beta have the same dtype, and according to the DeepSeek paper they should be fp32 ms : torch.Tensor mean square for each row of H from the projection kernel, of shape (M,), used for RMSNorm scaling n : int @@ -104,15 +192,17 @@ def mhc_fused_scale( Returns ------- h_pre : torch.Tensor - Scaled H_pre of shape (M, n), which aggregates (s, b, C, n) input of a Hyper Connection block into (s, b, n) as the input of attention / MLP + Scaled H_pre of shape (M, n), which aggregates (s, b, C, n) input of a Hyper Connection block into (s, b, n) as the input of attention / MLP, + with the same dtype as H h_post : torch.Tensor - Scaled H_post of shape (M, n), which expands the output of attention / MLP of shape (s, b, n) back to (s, b, C, n) for the residual connection + Scaled H_post of shape (M, n), which expands the output of attention / MLP of shape (s, b, n) back to (s, b, C, n) for the residual connection, + with the same dtype as H h_res : torch.Tensor - Scaled H_res of shape (M, n*n), which mixes the n streams of the (s, b, C, n) input of a Hyper Connection block + Scaled H_res of shape (M, n*n), which mixes the n streams of the (s, b, C, n) input of a Hyper Connection block, + with the same dtype as H """ assert n == 4, "Only n=4 is supported in this implementation" - check_deterministic("mhc_fused_scale") out = mHCScaleFusedOp.apply(H, alpha, beta, ms, n) h_pre = out[..., :n] h_post = out[..., n : 2 * n] @@ -120,7 +210,13 @@ def mhc_fused_scale( return h_pre, h_post, h_res -def mhc_fused_aggregate(x: torch.Tensor, H_pre: torch.Tensor, n: int, use_tf32: bool = True): +def mhc_fused_aggregate( + x: torch.Tensor, + H_pre: torch.Tensor, + n: int, + use_tf32: bool = True, + fuse_grad_x_acc: bool = False, +): """ Aggregate operation to merge n activation streams into one (see section 4.3.1 of the DeepSeek mHC paper): out = x @ H_pre: (s, b, C, n) @ (s, b, n, 1) -> (s, b, C, 1) -> (s, b, C) after squeezing the last dimension @@ -130,22 +226,28 @@ def mhc_fused_aggregate(x: torch.Tensor, H_pre: torch.Tensor, n: int, use_tf32: x : torch.Tensor input activation tensor of shape (s, b, C, n), where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection, and n is the number of hyper connections. Note that C is equal to the original hidden dimension divided by n. + dtype is torch.bfloat16 or torch.float32 H_pre: torch.Tensor input H_pre matrix of shape (s, b, n) + dtype is torch.bfloat16 or torch.float32 n: int number of hyper connections, where only n=4 is supported in the current implementation use_tf32: bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: if enabled, you must also enable this flag for `mhc_fused_projection` & `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. Returns ------- out: torch.Tensor - output activation tensor of shape (s, b, C), which is the aggregated output after merging n hyper connections + output activation tensor of shape (s, b, C), which is the aggregated output after merging n hyper connections, + with the same dtype as x """ assert n == 4, "Only n=4 is supported in this implementation" - check_deterministic("mhc_fused_aggregate") - out = mHCAggregateOp.apply(x, H_pre, n, use_tf32) + out = mHCAggregateOp.apply(x, H_pre, n, use_tf32, fuse_grad_x_acc) return out @@ -155,8 +257,8 @@ def mhc_fused_expand_combine( H_post: torch.Tensor, x: torch.Tensor, H_res: torch.Tensor, - n: int, use_tf32: bool = True, + fuse_grad_x_acc: bool = False, ): """ Expand and combine operation for merging n hyper connections (see section 4.3.1 of the DeepSeek mHC paper): @@ -167,70 +269,97 @@ def mhc_fused_expand_combine( ---------- f : torch.Tensor input activation tensor of shape (s, b, C), which is the output from the attention / FFN sub-layer in a transformer block + dtype is torch.bfloat16 or torch.float32 bias : torch.Tensor or None optional bias tensor of shape (C,) from the last linear layer, where f + bias is fused in this kernel for better performance + dtype is torch.bfloat16 or torch.float32 H_post : torch.Tensor input H_post matrix of shape (s, b, n) + dtype is torch.bfloat16 or torch.float32 x : torch.Tensor input activation tensor of shape (s, b, C, n), which is the hyper connection input before the aggregation operation + dtype is torch.bfloat16 or torch.float32 H_res : torch.Tensor input H_res matrix of shape (s, b, n, n) - n : int - number of hyper connections + dtype is torch.bfloat16 or torch.float32 use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: if enabled, you must also enable this flag for `mhc_fused_projection` & `mhc_fused_aggregate` or `mhc_generate_mix_and_aggregate` which is a wrapper of the former two, + so they can share the same buffer for activation's gradient accumulation. Returns ------- out : torch.Tensor - out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections + out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections, + with the same dtype as x """ + _, _, _, n = x.shape assert n == 4, "Only n=4 is supported in this implementation" - check_deterministic("mhc_fused_expand_combine") - out = mHCExpandCombineOp.apply( - f, - bias, - H_post, - x, - H_res, - n, - use_tf32, - ) + out = mHCExpandCombineOp.apply(f, bias, H_post, x, H_res, n, use_tf32, fuse_grad_x_acc) return out -def mhc_fused_projection(x: torch.Tensor, phi: torch.Tensor, use_tf32: bool = True): +def mhc_fused_projection( + x: torch.Tensor, + phi: torch.Tensor, + norm_weight: torch.Tensor = None, + use_tf32: bool = True, + fuse_grad_x_acc: bool = False, +): """ Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper): H = x @ phi^T: (M, K) @ (K, N) -> (M, N), which is padded to (M, 32) for better memory access pattern in the next kernels. ms = mean(x^2, dim=-1): (M,) + If norm_weight is provided, it will be absorbed into phi. In this case, the operation becomes: + Projection: + - H = x @ (phi.T * norm_weight) = x @ phi.T * norm_weight + - ms = mean(x^2, dim=-1) + - H = H / sqrt(ms) = x @ (phi.T * norm_weight) / sqrt(ms), where this step is fused into `mhc_fused_scale` + which is equivalent to performing the computation in the normal order: + - x_normalized = RMSNorm(x) = x * norm_weight / sqrt(ms) + - H = x_normalized @ phi.T = (x / sqrt(ms) @ phi.T) * norm_weight + Note: the current implementation only supports n=4 Parameters ---------- x : torch.Tensor input tensor of shape (M, K), where M=s*b is the batch size and K=nC is the hidden dimension after expansion. + dtype is torch.bfloat16 or torch.float32 phi : torch.Tensor projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4) + dtype is torch.bfloat16 or torch.float32 + norm_weight : torch.Tensor or None + optional, the weight for RMSNorm, of shape (K,), which is the learnable per-element affine parameters (gamma) applied to RMSNorm + dtype is torch.bfloat16 or torch.float32 use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail. + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: if enabled, you must also enable this flag for `mhc_fused_aggregate` & `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. Returns ------- H : torch.Tensor - Projected matrix of shape (M, 32), where only the first N elements in the last dimension are valid. + Projected matrix of shape (M, 32), where only the first N elements in the last dimension are valid, + with dtype float32 ms : torch.Tensor - Mean square of shape (M,), which is used for RMSNorm in the next kernel. + Mean square of shape (M,), which is used for RMSNorm in the next kernel, + with dtype float32 """ - assert ( - phi.shape[0] == 24 - ), "Currently only n=4 is supported, which means phi should have 24 in its first dimension" - check_deterministic("mhc_fused_projection") - H, ms = mHCProjectionOp.apply(x, phi, use_tf32) + assert phi.shape[0] == 24, ( + "Currently only n=4 is supported, which means phi should have 24 (or 32 if you padded phi)" + " in its first dimension" + ) + H, ms = mHCProjectionOp.apply(x, phi, norm_weight, use_tf32, fuse_grad_x_acc) return H, ms @@ -240,16 +369,20 @@ class mHCProjectionOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, phi, use_tf32=True): + def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fuse_grad_x_acc=False): """ The forward pass of the fused projection operation. Computes H = x @ phi^T and the mean + If norm_weight is provided, it will be absorbd by phi square ms = mean(x^2, dim=-1) for RMSNorm in a single fused kernel. Parameters: ctx : The context object. x (tensor): The input tensor of shape (M, K), where M=s*b is the flattened batch dimension and K=nC is the hidden dimension after expansion. phi (tensor): The projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4). + norm_weight (tensor or None): Optional, or tensor of shape (K,). RMSNorm's learnable per-element affine parameters use_tf32 (bool): Whether to use TF32 precision for matmul operations. If False, uses IEEE for better precision. + n (int): Number of hyper connections, where only n=4 is supported in the current implementation. + fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tuple: A tuple of (H, ms) where H is the projected matrix of shape (M, 32) padded for memory alignment (only the first N elements are valid), and ms is the mean square of shape (M,) in FP32. @@ -265,11 +398,15 @@ def forward(ctx, x, phi, use_tf32=True): N = phi.shape[0] + use_determinstic = is_deterministic_enforced() + # Pad H to (s, b, 32) for better memory access pattern in the kernel, but only the first N elements in the last dimension are valid - H = torch.zeros((M, 32), device=device, dtype=torch.float32) - ms = torch.zeros( - (M,), device=device, dtype=torch.float32 - ) # Mean square for x, used to compute RMSNorm in the next kernel + if use_determinstic: + H = torch.empty((M, 32), device=device, dtype=torch.float32) + ms = torch.empty((M,), device=device, dtype=torch.float32) + else: + H = torch.zeros((M, 32), device=device, dtype=torch.float32) + ms = torch.zeros((M,), device=device, dtype=torch.float32) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -277,11 +414,26 @@ def forward(ctx, x, phi, use_tf32=True): triton.cdiv(K, META["BLOCK_SIZE_K"]), ) + precision = "tf32" if ctx.use_tf32 else "ieee" + # If upcasting from bf16 to fp32 takes place inside the triton kernel, triton will ignore "ieee" precision and use tf32 anyway + # See https://github.com/triton-lang/triton/issues/10176 for detail. + # Therefore, we need to use tf32x3 instead which at least has better accuracy than tf32 just to make the tests pass. In production + # precision should be tf32 so it's not affected. + if precision == "ieee" and x.dtype == torch.bfloat16: + # When we have x is bf16, and either + # - phi is fp32, or + # - phi is bf16 but norm_weight is not None, where in this case inside the triton kernel, + # we will promote phi to fp32 because we want better precision for phi * norm_weight + # In both cases we will need to upcast x to fp32 inside the kernel, and trigger the issue mentioned above + if norm_weight is not None or phi.dtype == torch.float32: + precision = "tf32x3" + _mhc_projection_fwd_fused[grid]( x_ptr=x, # (M, K) phi_ptr=phi, # (N, K) h_ptr=H, # (M, 32) ms_ptr=ms, # (M,) + norm_weight_ptr=norm_weight, M=M, N=N, K=K, @@ -292,22 +444,36 @@ def forward(ctx, x, phi, use_tf32=True): stride_hm=32, stride_hn=1, stride_ms=1, + stride_norm_weight=1, BLOCK_SIZE_N=32, - precision="tf32" if use_tf32 else "ieee", + precision=precision, + HAS_NORM_WEIGHT=norm_weight is not None, + DETERMINISTIC=use_determinstic, ) - ctx.save_for_backward(x, phi, ms) + ctx.save_for_backward(x, phi, ms, norm_weight) ctx.phi_dtype = phi.dtype + ctx.precision = precision + ctx.fuse_grad_x_acc = fuse_grad_x_acc - return H.to(ctx.dtype), ms # Keep ms in fp32 + return H, ms # Keep both in fp32, which will be passed to sigmoid in mHCScaleFusedOp @staticmethod def backward(ctx, grad_H, grad_ms): """ The backward pass of the fused projection operation. Computes gradients for x and phi. - grad_phi = grad_H^T @ x, truncated to the first N rows. - grad_x = grad_H @ phi + 2 * x * grad_ms / K, where the second term is the gradient contribution from + - grad_psi = grad_H^T @ x: (2n + n^2, M) @ (M, nC) = (2n + n^2, nC), where grad_H's last dim is padded to 32 + If norm_weight is None: + - grad_phi = grad_psi + Otherwise, + - grad_phi = grad_psi * norm_weight: (2n + n^2, nC) * (nC,) = (2n + n^2, nC) + - grad_norm_weight = sum(grad_psi * phi, dim=0): ((2n + n^2, nC) * (2n + n^2, nC)).sum(dim=0) -> (nC,) + Reorder a bit: + - grad_phi = grad_H^T @ x * norm_weight + - grad_norm_weight = sum((grad_H^T @ x) * phi, dim=0) + + - grad_x = grad_H @ phi + 2 * x * grad_ms / K, where the second term is the gradient contribution from the mean square computation fused in the forward pass. Parameters: @@ -316,9 +482,9 @@ def backward(ctx, grad_H, grad_ms): grad_ms (tensor): The gradient of the loss with respect to the mean square, of shape (M,). Returns: - tuple: A tuple with the gradients (grad_x, grad_phi, None). + tuple: A tuple with the gradients (grad_x, grad_phi, grad_norm_weight, None). """ - x, phi, ms = ctx.saved_tensors + x, phi, ms, norm_weight = ctx.saved_tensors M, K = x.shape device = x.device @@ -332,12 +498,65 @@ def backward(ctx, grad_H, grad_ms): M, ) - grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + fuse_grad_x_acc = hasattr(x.untyped_storage(), "grad_x_acc") and ctx.fuse_grad_x_acc + if fuse_grad_x_acc: + grad_x = x.untyped_storage().grad_x_acc.view_as(x) + else: + grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + + if norm_weight is not None: + use_deterministic = is_deterministic_enforced() + # With norm_weight, we need a fused kernel to perform GEMM and output both phi & norm_weight gradients + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(K, META["BLOCK_SIZE_K"]), + triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) + + if use_deterministic: + grad_phi = torch.empty_like(phi, dtype=phi.dtype) + else: + grad_phi = torch.zeros_like(phi, dtype=torch.float32) + + if use_deterministic: + grad_norm_weight = torch.empty_like(norm_weight, dtype=norm_weight.dtype) + else: + grad_norm_weight = torch.zeros_like(norm_weight, dtype=torch.float32) + + _mhc_projection_bwd_fused_dphi[grid]( + x_ptr=x, # (M, K) + grad_H_ptr=grad_H, # (M, 32) + phi_ptr=phi, # (N, K) + norm_weight_ptr=norm_weight, # (K,) + grad_phi_ptr=grad_phi, # (N, K) + grad_norm_weight_ptr=grad_norm_weight, # (K,) + M=M, + N=N, + K=K, + stride_xm=K, + stride_xk=1, + stride_grad_Hm=32, + stride_grad_Hn=1, + stride_phin=K, + stride_phik=1, + stride_norm_weight=1, + stride_grad_phin=K, + stride_grad_phik=1, + stride_grad_norm_weight=1, + BLOCK_SIZE_N=32, + precision="tf32" if ctx.use_tf32 else "ieee", + DETERMINISTIC=is_deterministic_enforced(), + ) - grad_x = torch.empty((M, K), device=device, dtype=x.dtype) - grad_phi = general_gemm(x, grad_H, out_dtype=torch.float32, layout="NT")[0][:N, :].to( - phi.dtype - ) # (2n + n^2, M) @ (M, nC) = (2n + n^2, nC); grad_H's last dim is padded to 32 + grad_phi = grad_phi.to(phi.dtype) + grad_norm_weight = grad_norm_weight.to(norm_weight.dtype) + else: + # Without norm_weight, this is only a GEMM with no fusion needed so we let cuBLAS handle it + grad_phi = general_gemm( + x.to(grad_H.dtype), grad_H, out_dtype=torch.float32, layout="NT" + )[0][:N, :] + grad_phi = grad_phi.to(phi.dtype) + grad_norm_weight = None # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -345,10 +564,11 @@ def backward(ctx, grad_H, grad_ms): triton.cdiv(K, META["BLOCK_SIZE_K"]), ) - _mhc_projection_bwd_fused[grid]( + _mhc_projection_bwd_fused_dx[grid]( x_ptr=x, grad_x_ptr=grad_x, # (M, K) phi_ptr=phi, # (N, K) + norm_weight_ptr=norm_weight, # (K,) grad_h_ptr=grad_H, # (M, 32) grad_ms_ptr=grad_ms, # (M,) M=M, @@ -360,16 +580,22 @@ def backward(ctx, grad_H, grad_ms): stride_grad_xk=1, stride_phin=K, stride_phik=1, + stride_norm_weight=1, stride_grad_phin=K, stride_grad_phik=1, stride_grad_hm=32, stride_grad_hn=1, stride_grad_ms=1, BLOCK_SIZE_N=32, - precision="tf32" if ctx.use_tf32 else "ieee", + precision=ctx.precision, + FUSE_GRAD_X_ACC=fuse_grad_x_acc, + HAS_NORM_WEIGHT=norm_weight is not None, ) - return grad_x.to(ctx.dtype), grad_phi.to(ctx.dtype), None + if fuse_grad_x_acc: + del x.untyped_storage().grad_x_acc + + return grad_x.to(x.dtype), grad_phi, grad_norm_weight, None, None, None class mHCScaleFusedOp(torch.autograd.Function): @@ -415,8 +641,9 @@ def forward(ctx, H, alpha, beta, ms, n): (M, 32), device=H.device, dtype=H.dtype ) # Pad the output to 32 in the last dimension - # pylint: disable=unnecessary-lambda-assignment - grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),) + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 32 + grid = (triton.cdiv(M, BLOCK_SIZE_M),) _mhc_scale_fwd_fused[grid]( h_ptr=H, # (M, N), which is padded to (M, 32) @@ -433,7 +660,8 @@ def forward(ctx, H, alpha, beta, ms, n): stride_ms=1, stride_out_m=32, stride_out_n=1, # strides for out, which is padded to 32 in the last dimension - BLOCK_SIZE_N=32, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, eps=torch.finfo(ms.dtype).eps, ) @@ -459,58 +687,85 @@ def backward(ctx, grad_out): n = ctx.n grad_out = grad_out.contiguous() - grad_out = grad_out.to(torch.float32) M, _ = grad_out.shape N = 2 * n + n * n - grad_h = torch.zeros( - (M, 32), device=grad_out.device, dtype=grad_out.dtype - ) # Pad the grad_h to 32 in the last dimension - grad_alpha = torch.zeros((3,), device=grad_out.device, dtype=grad_out.dtype) - grad_beta_padded = torch.zeros((1, 32), device=grad_out.device, dtype=grad_out.dtype) - grad_beta = grad_beta_padded[ - :, :N - ] # Use only the first N elements for grad_beta, the rest are just padding - grad_ms = torch.zeros((M,), device=grad_out.device, dtype=grad_out.dtype) + grad_H = torch.zeros( + (M, 32), device=grad_out.device, dtype=H.dtype + ) # Pad the grad_H to 32 in the last dimension - # pylint: disable=unnecessary-lambda-assignment - grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),) + use_deterministic = is_deterministic_enforced() + + grad_ms = torch.zeros((M,), device=grad_out.device, dtype=ms.dtype) + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 32 + grid = (triton.cdiv(M, BLOCK_SIZE_M),) + + if use_deterministic: + grad_alpha = None + grad_beta_padded = None + workspace_buffer_grad_alpha = torch.empty( + (grid[0], 4), device=grad_out.device, dtype=torch.float32 + ) + workspace_buffer_grad_beta = torch.empty( + (grid[0], 32), device=grad_out.device, dtype=torch.float32 + ) + else: + grad_alpha = torch.zeros((3,), device=grad_out.device, dtype=torch.float32) + grad_beta_padded = torch.zeros((1, 32), device=grad_out.device, dtype=torch.float32) + workspace_buffer_grad_alpha = None + workspace_buffer_grad_beta = None _mhc_scale_bwd_fused[grid]( grad_out_ptr=grad_out, out_ptr=out, - grad_h_ptr=grad_h, - h_ptr=H, + grad_H_ptr=grad_H, + H_ptr=H, grad_a_ptr=grad_alpha, a_ptr=alpha, - grad_b_ptr=grad_beta, + grad_b_ptr=grad_beta_padded, grad_ms_ptr=grad_ms, ms_ptr=ms, + ws_grad_a_ptr=workspace_buffer_grad_alpha, + ws_grad_b_ptr=workspace_buffer_grad_beta, M=M, n=n, stride_grad_out_m=32, stride_grad_out_n=1, stride_out_m=32, stride_out_n=1, - stride_grad_hm=32, - stride_grad_hn=1, - stride_hm=32, - stride_hn=1, + stride_grad_Hm=32, + stride_grad_Hn=1, + stride_Hm=32, + stride_Hn=1, stride_grad_a=1, stride_a=1, stride_grad_b=1, stride_grad_ms=1, stride_ms=1, - BLOCK_SIZE_N=32, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, eps=torch.finfo(ms.dtype).eps, + DETERMINISTIC=use_deterministic, ) + if use_deterministic: + grad_alpha = workspace_buffer_grad_alpha.sum(dim=0)[ + :3 + ] # Sum across blocks and take the first 3 elements for grad_alpha + grad_beta_padded = workspace_buffer_grad_beta.sum( + dim=0, keepdim=True + ) # Sum across blocks for grad_beta + + grad_beta = grad_beta_padded[:, :N] + return ( - grad_h.to(ctx.dtype), - grad_alpha.to(ctx.dtype), - grad_beta.to(ctx.dtype), - grad_ms.to(ctx.dtype), + grad_H, + grad_alpha.to(alpha.dtype), + grad_beta.to(alpha.dtype), # We assume alpha and beta have the same dtype + grad_ms, None, ) @@ -550,7 +805,10 @@ def forward(ctx, H_res, n=4, recompute_hist=True, iters=20): H_res = H_res.contiguous().view(s * b, n * n) hist_f, hist_g = None, None - if not recompute_hist: + if recompute_hist: + hist_f = None + hist_g = None + else: # History buffers: (iters+1, s, b, n) hist_f = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype) hist_g = torch.empty((iters + 1, s, b, n), device=H_res.device, dtype=H_res.dtype) @@ -559,32 +817,20 @@ def forward(ctx, H_res, n=4, recompute_hist=True, iters=20): # pylint: disable=unnecessary-lambda-assignment grid = lambda META: (triton.cdiv(s * b * n * n, META["BLOCK_SIZE"]),) - if recompute_hist: - _mhc_sinkhorn_fwd_fused_recompute[grid]( - x_ptr=H_res, - output_ptr=H_res_out, - stride_xm=n * n, - stride_xn=1, - stride_out_m=n * n, - stride_out_n=1, - M=s * b, - n=n, - iters=iters, - ) - else: - _mhc_sinkhorn_fwd_fused[grid]( - x_ptr=H_res, - output_ptr=H_res_out, - hist_f_ptr=hist_f, - hist_g_ptr=hist_g, - stride_xm=n * n, - stride_xn=1, - stride_out_m=n * n, - stride_out_n=1, - M=s * b, - n=n, - iters=iters, - ) + _mhc_sinkhorn_fwd_fused[grid]( + x_ptr=H_res, + output_ptr=H_res_out, + hist_f_ptr=hist_f, + hist_g_ptr=hist_g, + stride_xm=n * n, + stride_xn=1, + stride_out_m=n * n, + stride_out_n=1, + M=s * b, + n=n, + iters=iters, + RECOMPUTE=recompute_hist, + ) if recompute_hist: ctx.save_for_backward(H_res, H_res_out) @@ -627,53 +873,33 @@ def backward(ctx, grad_out): n = ctx.n - grad_res_out = grad_out.clone().contiguous().view(M, n * n) + grad_res_out = grad_out.contiguous().view(M, n * n) grad_res = torch.empty_like(H_res) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: (triton.cdiv(M * n * n, META["BLOCK_SIZE"]),) - if recompute_hist: - _mhc_sinkhorn_bwd_fused_recompute[grid]( - grad_out_ptr=grad_res_out, - output_ptr=H_res_out, - grad_x_ptr=grad_res, - x_ptr=H_res, - hist_f_ptr=hist_f, - hist_g_ptr=hist_g, - stride_grad_out_m=n * n, - stride_grad_out_n=1, - stride_out_m=n * n, - stride_out_n=1, - stride_grad_xm=n * n, - stride_grad_xn=1, - stride_xm=n * n, - stride_xn=1, - M=M, - n=n, - iters=iters, - ) - else: - _mhc_sinkhorn_bwd_fused[grid]( - grad_out_ptr=grad_res_out, - output_ptr=H_res_out, - grad_x_ptr=grad_res, - x_ptr=H_res, - hist_f_ptr=hist_f, - hist_g_ptr=hist_g, - stride_grad_out_m=n * n, - stride_grad_out_n=1, - stride_out_m=n * n, - stride_out_n=1, - stride_grad_xm=n * n, - stride_grad_xn=1, - stride_xm=n * n, - stride_xn=1, - M=M, - n=n, - iters=iters, - ) + _mhc_sinkhorn_bwd_fused[grid]( + grad_out_ptr=grad_res_out, + output_ptr=H_res_out, + grad_x_ptr=grad_res, + x_ptr=H_res, + hist_f_ptr=hist_f, + hist_g_ptr=hist_g, + stride_grad_out_m=n * n, + stride_grad_out_n=1, + stride_out_m=n * n, + stride_out_n=1, + stride_grad_xm=n * n, + stride_grad_xn=1, + stride_xm=n * n, + stride_xn=1, + M=M, + n=n, + iters=iters, + RECOMPUTE=recompute_hist, + ) grad_res = grad_res.view(s, b, n, n) @@ -686,7 +912,7 @@ class mHCAggregateOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, H_pre, n, use_tf32=True): + def forward(ctx, x, H_pre, n, use_tf32=True, fuse_grad_x_acc=False): """ The forward pass of the aggregate operation. Merges n activation streams into one by computing a weighted sum using H_pre: @@ -699,6 +925,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True): H_pre (tensor): The pre-connection matrix of shape (s, b, n), used as weights for aggregation. n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. + fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The aggregated output of shape (s, b, C). @@ -735,6 +962,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True): ctx.save_for_backward(x, H_pre) ctx.n = n ctx.use_tf32 = use_tf32 + ctx.fuse_grad_x_acc = fuse_grad_x_acc return out @@ -763,10 +991,21 @@ def backward(ctx, grad_output): assert n == 4, "Only n=4 is supported in this implementation" M = s * b - grad_x = torch.empty_like(x) - grad_H_pre = torch.zeros( - (s, b, n), dtype=torch.float32, device=H_pre.device - ) # We need to use atomic_add for this so we need higher precision + fuse_grad_x_acc = hasattr(x.untyped_storage(), "grad_x_acc") and ctx.fuse_grad_x_acc + if fuse_grad_x_acc: + grad_x = x.untyped_storage().grad_x_acc.view_as(x) + else: + grad_x = torch.empty_like(x) + + use_deterministic = is_deterministic_enforced() + if use_deterministic: + grad_H_pre = torch.empty( + (s, b, n), dtype=H_pre.dtype, device=H_pre.device + ) # We need to use atomic_add for this so we need higher precision + else: + grad_H_pre = torch.zeros( + (s, b, n), dtype=torch.float32, device=H_pre.device + ) # We need to use atomic_add for this so we need higher precision # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -790,11 +1029,16 @@ def backward(ctx, grad_output): stride_grad_xm=nC, stride_grad_xCn=1, precision="tf32" if ctx.use_tf32 else "ieee", + FUSE_GRAD_X_ACC=fuse_grad_x_acc, + DETERMINISTIC=use_deterministic, ) grad_H_pre = grad_H_pre.to(H_pre.dtype) # Cast back to the original dtype of H_pre - return grad_x, grad_H_pre, None, None + if fuse_grad_x_acc: + grad_x = None + + return grad_x, grad_H_pre, None, None, None class mHCExpandCombineOp(torch.autograd.Function): @@ -803,7 +1047,7 @@ class mHCExpandCombineOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): + def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fuse_grad_x_acc=False): """ The forward pass of the expand and combine operation. Expands the sub-layer output f back to n streams using H_post, and combines with the residual connections using H_res: @@ -819,6 +1063,7 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): H_res (tensor): The residual connection matrix of shape (s, b, n, n). n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. + fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The expanded and combined output of shape (s, b, C, n). @@ -843,45 +1088,29 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): triton.cdiv(M, META["BLOCK_SIZE_M"]), ) - if bias is None: - _mhc_expand_combine_fwd[grid]( - f_ptr=f, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - output_ptr=out, - M=M, - C=C, - n=n, - stride_fm=C, - stride_fc=1, - stride_xm=Cn, - stride_xCn=1, - stride_output_m=Cn, - stride_output_Cn=1, - ) - else: - _mhc_expand_combine_with_bias_fwd[grid]( - f_ptr=f, - bias_ptr=bias, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - output_ptr=out, - M=M, - C=C, - n=n, - stride_fm=C, - stride_fc=1, - stride_bias=1, - stride_xm=Cn, - stride_xCn=1, - stride_output_m=Cn, - stride_output_Cn=1, - ) + _mhc_expand_combine_fwd[grid]( + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + output_ptr=out, + M=M, + C=C, + n=n, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=Cn, + stride_xCn=1, + stride_output_m=Cn, + stride_output_Cn=1, + HAS_BIAS=bias is not None, + ) ctx.n = n ctx.have_bias = bias is not None + ctx.fuse_grad_x_acc = fuse_grad_x_acc if bias is not None: ctx.save_for_backward(f, bias, H_post, x, H_res) else: @@ -918,15 +1147,39 @@ def backward(ctx, grad_output): f, H_post, x, H_res = ctx.saved_tensors M = s * b + use_deterministic = is_deterministic_enforced() + grad_f = torch.empty_like(f) - grad_bias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None - grad_H_post = torch.zeros_like( - H_post, dtype=torch.float32 - ) # We need to use atomic_add for this so we need higher precision grad_x = torch.empty_like(x) - grad_H_res = torch.zeros_like( - H_res, dtype=torch.float32 - ) # We need to use atomic_add for this so we need higher precision + + grad_bias_workspace = None + # Since triton's autotune will reset grad_bias pointer when tuning, we need an empty placeholder here + grad_bias = torch.empty(1, device=grad_output.device, dtype=grad_output.dtype) + if use_deterministic: + grad_H_post = torch.empty_like( + H_post, dtype=H_post.dtype + ) # No need for higher precision since we don't use atomic_add + grad_H_res = torch.empty_like( + H_res, dtype=H_res.dtype + ) # No need for higher precision since we don't use atomic_add + if bias is not None: + # Since grad_bias is reducing over M dimension, we must use a separate workspace for it + # because our kernel parallelizes over M dimension even in deterministic mode + # 4 is the hardcoded BLOCK_SIZE_M in the deterministic mode so we only need to allocate a (M // 4, C) buffer + grad_bias_workspace = torch.empty( + triton.cdiv(M, 4), C, device=bias.device, dtype=torch.float32 + ) + else: + grad_H_post = torch.zeros_like( + H_post, dtype=torch.float32 + ) # We need to use atomic_add for this so we need higher precision + grad_H_res = torch.zeros_like( + H_res, dtype=torch.float32 + ) # We need to use atomic_add for this so we need higher precision + if bias is not None: + grad_bias = ( + torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None + ) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -934,66 +1187,61 @@ def backward(ctx, grad_output): triton.cdiv(M, META["BLOCK_SIZE_M"]), ) + _mhc_expand_combine_bwd[grid]( + grad_output_ptr=grad_output, + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + grad_H_post_ptr=grad_H_post, + grad_f_ptr=grad_f, + grad_bias_ptr=grad_bias, + grad_bias_ws_ptr=grad_bias_workspace, + grad_H_res_ptr=grad_H_res, + grad_x_ptr=grad_x, + M=M, + C=C, + n=n, + stride_grad_output_m=n * C, + stride_grad_output_Cn=1, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=n * C, + stride_xCn=1, + stride_grad_fm=C, + stride_grad_fc=1, + stride_grad_bias=1, + stride_grad_bias_ws_m=C, + stride_grad_bias_ws_c=1, + stride_grad_xm=n * C, + stride_grad_xCn=1, + precision="tf32" if ctx.use_tf32 else "ieee", + HAS_BIAS=bias is not None, + DETERMINISTIC=use_deterministic, + ) + + if use_deterministic and bias is not None: + # Reduce the grad_bias_workspace to get the final grad_bias + grad_bias = grad_bias_workspace.sum(dim=0).to(bias.dtype) + # If no bias, replace the grad_bias placeholder with None if bias is None: - _mhc_expand_combine_bwd[grid]( - grad_output_ptr=grad_output, - f_ptr=f, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - grad_H_post_ptr=grad_H_post, - grad_f_ptr=grad_f, - grad_H_res_ptr=grad_H_res, - grad_x_ptr=grad_x, - M=M, - C=C, - n=n, - stride_grad_output_m=n * C, - stride_grad_output_Cn=1, - stride_fm=C, - stride_fc=1, - stride_xm=n * C, - stride_xCn=1, - stride_grad_fm=C, - stride_grad_fc=1, - stride_grad_xm=n * C, - stride_grad_xCn=1, - precision="tf32" if ctx.use_tf32 else "ieee", - ) - else: - _mhc_expand_combine_with_bias_bwd[grid]( - grad_output_ptr=grad_output, - f_ptr=f, - bias_ptr=bias, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - grad_H_post_ptr=grad_H_post, - grad_f_ptr=grad_f, - grad_bias_ptr=grad_bias, - grad_H_res_ptr=grad_H_res, - grad_x_ptr=grad_x, - M=M, - C=C, - n=n, - stride_grad_output_m=n * C, - stride_grad_output_Cn=1, - stride_fm=C, - stride_fc=1, - stride_bias=1, - stride_xm=n * C, - stride_xCn=1, - stride_grad_fm=C, - stride_grad_fc=1, - stride_grad_bias=1, - stride_grad_xm=n * C, - stride_grad_xCn=1, - precision="tf32" if ctx.use_tf32 else "ieee", - ) + grad_bias = None grad_H_post = grad_H_post.to(H_post.dtype) # Cast back to the original dtype of H_post grad_H_res = grad_H_res.to(H_res.dtype) # Cast back to the original dtype of H_res if bias is not None: grad_bias = grad_bias.to(bias.dtype) - return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None + if ctx.fuse_grad_x_acc: + assert not hasattr(x.untyped_storage(), "grad_x_acc"), ( + "Unexpected: grad_x_acc is already attached in x's storage. This implies incorrect" + " usage of `fuse_grad_x_acc` optimization. Please disable fuse_grad_x_acc or check" + " if there are other places where grad_x_acc is attached to x's storage." + ) + # When fused x gradient accumulation is enabled, use fp32 for the accumulation buffer + x.untyped_storage().grad_x_acc = grad_x.to(torch.float32) + grad_x = None + + return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None, None