diff --git a/ci/pytorch.sh b/ci/pytorch.sh index a6ad620fe..016829a8f 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -70,6 +70,7 @@ run_test_config(){ run_default_fa 1 triton_kernels/test_cast_mxfp8.py run_default_fa 1 triton_kernels/test_norm_common.py run_default_fa 1 triton_kernels/test_norms.py + run_default_fa 1 triton_kernels/test_grouped_gemm.py NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py run_default_fa 1 test_parallel_cross_entropy.py NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py diff --git a/setup.py b/setup.py index b28644e03..d9b6d8c9c 100644 --- a/setup.py +++ b/setup.py @@ -170,7 +170,10 @@ def setup_requirements() -> Tuple[List[str], List[str]]: install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} - package_data = {"": ["VERSION.txt"]} + package_data = { + "": ["VERSION.txt"], + "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"], + } include_package_data = True extras_require = {"test": test_requires} diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b2885d677..ec16740f2 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -11,6 +11,7 @@ import pytest import random +from triton_kernels.test_common import get_tolerances import torch import torch.nn as nn from torch.nn import Parameter @@ -2016,6 +2017,118 @@ def _test_grouped_linear_accuracy( return outputs +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("recipe", [None]) +@pytest.mark.parametrize("fp8_model_params", [False]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", [False]) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_triton_accuracy( + dtype, + num_gemms, + bs, + model, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + bias, + delay_wgrad_compute, + parallel_mode=None, +): + os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1" + fp8 = recipe is not None + + if IS_HIP_EXTENSION: + if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8: + pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8 and recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: + pytest.skip(reason_for_no_fp8_block_scaling) + + config = model_configs[model] + if config.seq_len % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + delay_wgrad_compute=delay_wgrad_compute, + save_original_input=False, + ).eval() + sequential_linear = torch.nn.ModuleList( + [ + Linear( + 
config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + for _ in range(num_gemms) + ] + ) + + # Share params + with torch.no_grad(): + for i in range(num_gemms): + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() + + outputs_ref = _test_grouped_linear_accuracy( + sequential_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + + # Shoule be bit-wise match + atol, rtol = get_tolerances(dtype) + if dtype == torch.float32: + atol = 2.6e-6 + rtol = 5e-2 + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=rtol, atol=atol) + os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "0" + + @pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) diff --git a/tests/pytorch/triton_kernels/test_grouped_gemm.py b/tests/pytorch/triton_kernels/test_grouped_gemm.py new file mode 100644 index 000000000..1e7131fca --- /dev/null +++ b/tests/pytorch/triton_kernels/test_grouped_gemm.py @@ -0,0 +1,516 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# This file is from aiter project (https://github.com/ROCm/aiter) +# commit:04dc719, directory: aiter/op_tests/triton_tests/test_gmm.py + +# Imports. +# ------------------------------------------------------------------------------ + +# Python standard library +from functools import partial + +# PyTorch +import torch +from torch import Tensor + +# pytest +import pytest + +# AITER: GMM defaults and utility functions +from transformer_engine.pytorch.triton_kernels.gmm.gmm_common import ( + SUPPORTED_DTYPES_STR, + DTYPE, + dtype_from_str, + check_input_device_dtype, + gen_gmm_tensors, + get_gmm_shape, + get_gmm_output, + gen_tgmm_tensors, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, +) + +# AITER: Triton kernel wrappers +from transformer_engine.pytorch.triton_kernels.gmm import ( + gmm as triton_gmm, + ptgmm as triton_ptgmm, + nptgmm as triton_nptgmm, +) + + +# Common code shared by GMM and TGMM unit tests. +# ------------------------------------------------------------------------------ + + +# Shapes. + +# Shapes used only for test purposes. +# fmt: off +TEST_ONLY_SHAPES: list[tuple[int, int, int, int]] = [ + # M, K, N, G + ( 10, 2, 3, 4), + ( 32, 16, 8, 4), + (512, 4096, 2048, 160), +] +# fmt: on + +# Real shapes, used by real models. +# fmt: off +REAL_SHAPES: list[tuple[int, int, int, int]] = [ + # M, K, N, G + ( 49152, 1408, 2048, 64), # deepseekv2-16B + (3145728, 2048, 1408, 8), # deepseekv2-16B + ( 393216, 2048, 1408, 64), # deepseekv2-16B + ( 32768, 6144, 16384, 8), # Mixtral 8x22B + ( 32768, 16384, 6144, 8), # Mixtral 8x22B +] +# fmt: on + +# Test shapes are test only + real ones. 
+TEST_SHAPES: list[tuple[int, int, int, int]] = TEST_ONLY_SHAPES + REAL_SHAPES + + +# Input and output types. + +INPUT_DTYPES_STR: set[str] = {"i" + dtype_str for dtype_str in SUPPORTED_DTYPES_STR} +OUTPUT_DTYPES_STR: set[str] = {"o" + dtype_str for dtype_str in SUPPORTED_DTYPES_STR} + + +# Transpositions. + +TRANS_LSH_STR: set[str] = {f"tlhs{b}" for b in {"F", "T"}} +TRANS_RHS_STR: set[str] = {f"trhs{b}" for b in {"F", "T"}} + + +def trans_from_str(trans_str: str, tensor_str: str) -> bool: + assert tensor_str in {"lhs", "rhs"}, f"Invalid tensor string ({tensor_str})." + return trans_str.replace(f"t{tensor_str}", "") == "T" + + +trans_lhs_from_str = partial(trans_from_str, tensor_str="lhs") +trans_rhs_from_str = partial(trans_from_str, tensor_str="rhs") + + +# RNG seed. + +RNG_SEED_STR: set[str] = {f"rng{rng_seed}" for rng_seed in {77, 121}} + + +def rng_seed_from_str(rng_seed_str: str) -> int: + rng_seed_int = int(rng_seed_str.replace("rng", "")) + assert rng_seed_int >= 0, f"RNG seed must be non-negative (it's {rng_seed_int})." + return rng_seed_int + + +# Number of distinct group sizes for each test shape. +NUM_GROUP_SIZES: int = 5 + + +# Tensor comparison. +def check_tensors( + actual: Tensor, + expected: Tensor, + msg: str, + atol: float | None = None, + rtol: float | None = None, +) -> None: + if atol is None: + atol = 5e-3 + else: + assert atol > 0, f"Absolute tolerance must be positive (it's {atol})." + if rtol is None: + rtol = 1e-2 + else: + assert rtol > 0, f"Relative tolerance must be positive (it's {rtol})." + torch.testing.assert_close( + actual, + expected, + atol=atol, + rtol=rtol, + msg=lambda torch_msg: f"{msg}\n\n{torch_msg}\n", + ) + + +# GMM unit tests. +# ------------------------------------------------------------------------------ + + +def torch_gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + bias: Tensor | None = None, +) -> Tensor: + check_input_device_dtype(lhs, rhs, group_sizes) + + M, _, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + last_row = 0 + + for g in range(G): + m = int(group_sizes[g].item()) + + # Skip group if there are no tokens assigned to the expert. 
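+        # (A zero-size group contributes no rows to the output, so there is
+        # nothing to write for it; skipping also avoids an empty matmul.)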
+ if m == 0: + continue + + start_idx = last_row + end_idx = last_row + m + + result = (lhs[start_idx:end_idx, :] @ rhs[g]).to(torch.float32) + if bias is not None: + result += bias[g].to(torch.float32) + out[start_idx:end_idx, :] = result.to(preferred_element_type) + + last_row += m + + return out + + +@pytest.mark.parametrize("M, K, N, G", TEST_SHAPES) +@pytest.mark.parametrize("in_dtype_str", INPUT_DTYPES_STR) +@pytest.mark.parametrize("out_dtype_str", OUTPUT_DTYPES_STR) +@pytest.mark.parametrize("trans_rhs_str", TRANS_RHS_STR) +@pytest.mark.parametrize("rng_seed_str", RNG_SEED_STR) +@pytest.mark.parametrize("use_bias", [False, True]) +def test_gmm( + M: int, + K: int, + N: int, + G: int, + in_dtype_str: str, + out_dtype_str: str, + trans_rhs_str: str, + rng_seed_str: str, + use_bias: bool, +): + in_dtype = dtype_from_str(in_dtype_str) + out_dtype = dtype_from_str(out_dtype_str) + trans_rhs = trans_rhs_from_str(trans_rhs_str) + rng_seed = rng_seed_from_str(rng_seed_str) + + lhs, rhs, multiple_group_sizes, out_torch, bias = gen_gmm_tensors( + M, + K, + N, + G, + NUM_GROUP_SIZES, + input_type=in_dtype, + output_type=out_dtype, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=True, # 1st group_sizes in test is evenly distributed + use_bias=use_bias, + ) + out_triton = torch.empty_like(out_torch) + + for group_sizes in multiple_group_sizes: + torch_gmm( + lhs, + rhs, + group_sizes, + preferred_element_type=out_dtype, + existing_out=out_torch, + bias=bias, + ) + + triton_gmm( + lhs, + rhs, + group_sizes, + preferred_element_type=out_dtype, + existing_out=out_triton, + bias=bias, + ) + + m = int(torch.sum(group_sizes).item()) + + # Tolerance handling: + # - Default (no bias): use strict global defaults (atol=5e-3, rtol=1e-2) + # - With bias: allow slightly looser tolerances due to: + # * extra floating point op (add bias) + # * large problem sizes and mixed precision + # * very small fraction of elements differing by a few bf16/fp16 ULPs + if use_bias: + # Base tolerances for bias case. + atol = 0.02 + rtol = 0.02 + else: + atol = None + rtol = None + + check_tensors( + out_triton[:m], + out_torch[:m], + "Triton GMM doesn't match PyTorch reference GMM.", + atol=atol, + rtol=rtol, + ) + + +# TGMM unit tests. +# ------------------------------------------------------------------------------ + + +def torch_tgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + # Bias gradient handling (test/reference only). + # Get or validate bias gradient tensor (validates and optionally zeros it). + compute_bias_grad = bias_grad is not None + bias_grad = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + last_col = 0 + + for g in range(G): + m = int(group_sizes[g].item()) + + # Skip group if there are no columns assigned to the group. + if m == 0: + continue + + start_idx = last_col + end_idx = last_col + m + mm = lhs[:, start_idx:end_idx] @ rhs[start_idx:end_idx, :] + out[g] = mm.to(preferred_element_type) + + # Bias gradient: sum lhs across m-dimension (columns) for each group. 
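+        # (lhs is (K, M) here, so summing each group's column slice over dim=1
+        # yields a (K,)-vector per group; float32 matches the kernel's
+        # float32 atomic-add accumulator for bias_grad.)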
+ if compute_bias_grad: + grad = lhs[:, start_idx:end_idx].sum(dim=1, dtype=torch.float32) + bias_grad[g] += grad + + last_col += m + + return out + + +@pytest.mark.parametrize("persistent_str", {"p", "np"}) +@pytest.mark.parametrize("with_bias_grad", [False, True]) +@pytest.mark.parametrize("M, K, N, G", TEST_SHAPES) +@pytest.mark.parametrize("in_dtype_str", INPUT_DTYPES_STR) +@pytest.mark.parametrize("out_dtype_str", OUTPUT_DTYPES_STR) +@pytest.mark.parametrize("trans_lhs_str", TRANS_LSH_STR) +@pytest.mark.parametrize("rng_seed_str", RNG_SEED_STR) +def test_tgmm( + persistent_str: str, + with_bias_grad: bool, + M: int, + K: int, + N: int, + G: int, + in_dtype_str: str, + out_dtype_str: str, + trans_lhs_str: str, + rng_seed_str: str, +): + assert persistent_str in {"p", "np"} + persistent: bool = persistent_str == "p" + + in_dtype = dtype_from_str(in_dtype_str) + out_dtype = dtype_from_str(out_dtype_str) + trans_lhs = trans_lhs_from_str(trans_lhs_str) + rng_seed = rng_seed_from_str(rng_seed_str) + + lhs, rhs, multiple_group_sizes, out_torch, bias_grad_torch = gen_tgmm_tensors( + M, + K, + N, + G, + NUM_GROUP_SIZES, + input_type=in_dtype, + output_type=out_dtype, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=True, # 1st group_sizes in test is evenly distributed + use_bias=with_bias_grad, + ) + out_triton = torch.empty_like(out_torch) + bias_grad_triton = torch.empty_like(bias_grad_torch) if with_bias_grad else None + + # For big shape (M, K, N, G) = (3145728, 2048, 1408, 8) there are some element + # mismatches (125 / 23068672 ~ 0.00013%) with absolute error greater than the + # default tolerance. This behavior is deterministic and, given a RNG seed, + # always happen for the same output elements. So, absolute tolerance is increased + # only for this shape. + atol = 2.5e-2 if M > 1e6 else None + + kernel_wrapper = triton_ptgmm if persistent else triton_nptgmm + + for group_sizes in multiple_group_sizes: + # Reference implementation. + torch_tgmm( + lhs, + rhs, + group_sizes, + preferred_element_type=out_dtype, + existing_out=out_torch, + bias_grad=bias_grad_torch, + accumulate=False, + ) + + # Triton kernel. + kernel_wrapper( + lhs, + rhs, + group_sizes, + preferred_element_type=out_dtype, + existing_out=out_triton, + bias_grad=bias_grad_triton, + accumulate=False, + ) + non_empty_groups = group_sizes > 0 + + # Compare TGMM outputs. + check_tensors( + out_triton[non_empty_groups], + out_torch[non_empty_groups], + f"Triton {'persistent' if persistent else 'non-persistent'} TGMM doesn't match PyTorch reference TGMM.", + atol=atol, + ) + + # For persistent TGMM, also compare bias gradients on smaller shapes. + # + # For very large shapes (e.g., M > 1e6), bias_grad is an extremely long + # float32 reduction with atomics in the Triton kernel and a different + # reduction order in the PyTorch reference. Per-element comparisons + # become dominated by reduction-order noise rather than meaningful + # correctness checks, so we skip bias_grad comparison there and rely + # only on the output tensor check above. 
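+        # (Float32 accumulation is not associative: summing the same values in a
+        # different order can change the last few bits, e.g. x.sum() and
+        # x.flip(0).sum() typically differ slightly for a long random vector, so
+        # per-element agreement between the two reduction orders cannot be
+        # expected at this scale.)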
+ if with_bias_grad and M <= 1e6: + bias_atol = 1.7 + bias_rtol = 0.1 + + check_tensors( + bias_grad_triton[non_empty_groups], + bias_grad_torch[non_empty_groups], + "Triton persistent TGMM bias_grad doesn't match PyTorch reference TGMM bias_grad.", + atol=bias_atol, + rtol=bias_rtol, + ) + + +@pytest.mark.parametrize("persistent_str", {"p", "np"}) +@pytest.mark.parametrize("with_bias_grad", [False, True]) +def test_tgmm_accumulate(persistent_str: str, with_bias_grad: bool): + persistent: bool = persistent_str == "p" + + """Test ACCUMULATE semantics for persistent TGMM on a small, focused case.""" + # Use the smallest TEST_ONLY_SHAPES entry to keep runtime low. + M, K, N, G = TEST_ONLY_SHAPES[0] + + in_dtype = DTYPE + out_dtype = DTYPE + trans_lhs = False + rng_seed = 77 + + lhs, rhs, multiple_group_sizes, out_torch, bias_grad_torch = gen_tgmm_tensors( + M, + K, + N, + G, + NUM_GROUP_SIZES, + input_type=in_dtype, + output_type=out_dtype, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=True, + use_bias=with_bias_grad, + ) + + # Take a single group_sizes configuration for this targeted test. + group_sizes = multiple_group_sizes[0] + non_empty_groups = group_sizes > 0 + + # Base output to accumulate into. + base_out = torch.randn_like(out_torch) + + # Reference: compute TGMM delta into a fresh buffer, then add to base_out. + delta_ref = torch.empty_like(out_torch) + torch_tgmm( + lhs, + rhs, + group_sizes, + preferred_element_type=out_dtype, + existing_out=delta_ref, + bias_grad=bias_grad_torch, + accumulate=False, + ) + expected = base_out.clone() + expected[non_empty_groups] = ( + expected[non_empty_groups] + delta_ref[non_empty_groups] + ) + + # Triton PTGMM/NPTGMM with ACCUMULATE=True. + out_triton = base_out.clone() + bias_grad_triton = torch.empty_like(bias_grad_torch) if with_bias_grad else None + + if persistent: + triton_ptgmm( + lhs, + rhs, + group_sizes, + preferred_element_type=out_dtype, + existing_out=out_triton, + bias_grad=bias_grad_triton, + accumulate=True, + ) + else: + triton_nptgmm( + lhs, + rhs, + group_sizes, + preferred_element_type=out_dtype, + existing_out=out_triton, + bias_grad=bias_grad_triton, + accumulate=True, + ) + + check_tensors( + out_triton[non_empty_groups], + expected[non_empty_groups], + "Triton persistent TGMM ACCUMULATE semantics do not match reference behavior.", + ) + + # Check bias_grad + if with_bias_grad: + check_tensors( + bias_grad_triton[non_empty_groups], + bias_grad_torch[non_empty_groups], + "Triton persistent TGMM bias_grad with ACCUMULATE=True does not match reference.", + ) \ No newline at end of file diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 9f3921d36..42722130d 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -134,6 +134,7 @@ def general_grouped_gemm( use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, single_output=False, + **kwargs, ) -> Tuple[List[torch.Tensor], ...]: """ TN layout Grouped GEMM with fp8 inputs. diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index da66e68b4..9443d679f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,12 +3,14 @@ # See LICENSE for license information. 
"""GroupedLinear API""" +import os from typing import Union, Optional, Callable, Tuple, List import warnings import functools import torch +from transformer_engine.pytorch.triton_kernels.grouped_gemm import general_grouped_gemm_triton import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe @@ -49,6 +51,7 @@ prepare_for_saving, restore_from_saved, ) +from torch.utils.cpp_extension import IS_HIP_EXTENSION __all__ = ["GroupedLinear"] @@ -80,10 +83,13 @@ def forward( module, skip_fp8_weight_update, save_original_input, + m_splits_tensor: Optional[torch.Tensor], # Optional GPU tensor for triton kernel *weights_and_biases, ) -> torch.Tensor: # pylint: disable=missing-function-docstring + # Check if Triton kernel should be used + use_grouped_gemm_triton = IS_HIP_EXTENSION and os.getenv("NVTE_USE_GROUPED_GEMM_TRITON", "0") == "1" and not fp8 and not fuse_wgrad_accumulation num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] biases = weights_and_biases[num_gemms:] @@ -126,8 +132,10 @@ def forward( if fp8: inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) else: - inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) - + if not use_grouped_gemm_triton: + inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) + else: + inputmats = [cast_if_needed(inp_view, activation_dtype)] # Initialize weights weights_fp8: list if fp8: @@ -168,17 +176,26 @@ def forward( use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator # Perform GEMM - _ = general_grouped_gemm( + general_grouped_gemm_func = general_grouped_gemm_triton if use_grouped_gemm_triton else general_grouped_gemm + # Prepare m_splits for each backend + m_splits_for_kernel = m_splits + if use_grouped_gemm_triton: + # Triton kernel needs GPU tensor + if m_splits_tensor is None: + m_splits_tensor = torch.tensor(m_splits, dtype=torch.int32, device=device) + m_splits_for_kernel = m_splits_tensor + _ = general_grouped_gemm_func( weights_fp8, inputmats, [out], activation_dtype, get_multi_stream_cublas_workspace(), single_output=True, - m_splits=m_splits, + m_splits=m_splits_for_kernel, bias=biases, use_bias=use_bias, use_split_accumulator=use_split_accumulator, + m_splits_list=m_splits, ) if fp8_calibration: @@ -235,6 +252,7 @@ def forward( ctx.device = device ctx.grad_output_quantizers = grad_output_quantizers ctx.m_splits = m_splits + ctx.m_splits_for_kernel = m_splits_for_kernel ctx.num_gemms = num_gemms ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 @@ -255,6 +273,8 @@ def forward( ctx.wgrad_store = wgrad_store ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.use_grouped_gemm_triton = use_grouped_gemm_triton + ctx.num_input_tensors = len(inputmats) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -265,10 +285,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], with torch.cuda.nvtx.range("_GroupedLinear_backward"): saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) N = ctx.num_gemms - inputmats = saved_tensors[:N] - weights = saved_tensors[N : 2 * N] - origin_weights = saved_tensors[2 * N : 3 * N] - biases = saved_tensors[3 * N : 4 * N] + num_inputs = ctx.num_input_tensors + inputmats = saved_tensors[:num_inputs] + weights = saved_tensors[num_inputs: num_inputs + N] + origin_weights = saved_tensors[num_inputs + N : num_inputs + 2 * N] + 
biases = saved_tensors[num_inputs + 2 * N : num_inputs + 3 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -311,10 +332,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: # Only split grad output. Grad bias is fused with # wgrad GEMM. - grad_output = torch.split( + if not ctx.use_grouped_gemm_triton: + grad_output = torch.split( cast_if_needed(grad_output_view, ctx.activation_dtype), ctx.m_splits, ) + else: + grad_output = [cast_if_needed(grad_output_view, ctx.activation_dtype)] if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -343,7 +367,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], rowwise_usage=quantizer.rowwise_usage, columnwise_usage=quantizer.columnwise_usage, ) - general_grouped_gemm( + general_grouped_gemm_func = general_grouped_gemm_triton if ctx.use_grouped_gemm_triton else general_grouped_gemm + general_grouped_gemm_func( weights, grad_output, [dgrad], @@ -351,9 +376,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], get_multi_stream_cublas_workspace(), single_output=True, layout="NN", - m_splits=ctx.m_splits, + m_splits=ctx.m_splits_for_kernel, grad=True, use_split_accumulator=dgrad_gemm_use_split_accumulator, + m_splits_list=ctx.m_splits, ) if ctx.weights_requires_grad: @@ -367,10 +393,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: - wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) - for w in weights - ] + if not ctx.use_grouped_gemm_triton: + wgrad_list = [ + torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) + for w in weights + ] + else: + wgrad_list = torch.empty( + (ctx.num_gemms, weights[0].size(0), weights[0].size(1)), + dtype=ctx.activation_dtype, + device=ctx.device + ) if ctx.save_original_input: inp = inputmats[0] @@ -388,17 +421,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.fp8: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) else: - inputmats = torch.split( - cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits - ) + if not ctx.use_grouped_gemm_triton: + inputmats = torch.split( + cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + ) + else: + inputmats = [cast_if_needed(inp_view, ctx.activation_dtype)] grouped_gemm_wgrad = functools.partial( - general_grouped_gemm, + general_grouped_gemm_func, out_dtype=ctx.activation_dtype, workspaces=get_multi_stream_cublas_workspace(), layout="NT", grad=True, - m_splits=ctx.m_splits, + m_splits=ctx.m_splits_for_kernel, use_bias=ctx.use_bias if grad_biases[0] is None else None, bias=biases, use_split_accumulator=wgrad_gemm_use_split_accumulator, @@ -446,8 +482,8 @@ def handle_custom_ddp_from_mcore(weight, wgrad): return wgrad wgrad_list = [ - handle_custom_ddp_from_mcore(weight, wgrad) - for weight, wgrad in zip(origin_weights, wgrad_list) + handle_custom_ddp_from_mcore(weight, wgrad_list[i]) + for i, weight in enumerate(origin_weights) ] else: wgrad_list = [None] * ctx.num_gemms @@ -484,6 +520,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, + None, *wgrad_list, *grad_biases, ) @@ -706,6 +743,7 @@ def forward( inp: torch.Tensor, m_splits: List[int], is_first_microbatch: Optional[bool] = None, + m_splits_tensor: 
Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -798,6 +836,7 @@ def forward( self, skip_fp8_weight_update, self.save_original_input, + m_splits_tensor, *weight_tensors, *bias_tensors, ) diff --git a/transformer_engine/pytorch/triton_kernels/common.py b/transformer_engine/pytorch/triton_kernels/common.py index 71afb2fd3..6b1ca0462 100644 --- a/transformer_engine/pytorch/triton_kernels/common.py +++ b/transformer_engine/pytorch/triton_kernels/common.py @@ -5,6 +5,7 @@ import triton import triton.language as tl import transformer_engine_torch as tex +from functools import lru_cache def is_cdna4(): return triton.runtime.driver.active.get_current_target().arch == "gfx950" @@ -98,3 +99,6 @@ def get_fp8_max(dtype: tex.DType): if dtype == tex.DType.kFloat8E5M2: return 57344.0 +@lru_cache(maxsize=1) +def get_arch(): + return triton.runtime.driver.active.get_current_target().arch \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/gmm/__init__.py b/transformer_engine/pytorch/triton_kernels/gmm/__init__.py new file mode 100644 index 000000000..cf914e493 --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. + +"""GMM (Grouped Matrix Multiplication) kernels.""" + +from .gmm_wrapper import gmm, ptgmm, nptgmm + +__all__ = ["gmm", "ptgmm", "nptgmm"] \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx942-GMM.json b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx942-GMM.json new file mode 100644 index 000000000..7785e5aeb --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx942-GMM.json @@ -0,0 +1,204 @@ +{ + "gmm": { + "default": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 8, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "tiny_shapes": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 4, + "num_stages": 1 + }, + "k_heavy": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 8, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "n_heavy": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 16, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "balanced_large_n": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "very_large_m": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "small_shapes": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 16, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "small_m_moderate_n": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 8, + "GRID_DIM": 304, + "num_warps": 4, + "num_stages": 1 + }, + "default_no_trans_rhs": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 16, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "k_heavy_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 8, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 
1 + }, + "very_large_m_small_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "balanced_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 8, + "GRID_DIM": 304, + "num_warps": 4, + "num_stages": 1 + }, + "n_very_heavy_bwd": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 16, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "small_k_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + } + + }, + "ptgmm": { + "default": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 4, + "num_stages": 1 + }, + "high_group_count": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "small_n": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 4, + "num_stages": 1 + }, + "small_n_high_group": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + }, + "accumulate": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 304, + "num_warps": 8, + "num_stages": 1 + } + }, + "nptgmm": { + "default": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 1, + "num_warps": 4, + "num_stages": 1 + }, + "small_n": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 1 + }, + "accumulate": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 1 + } + } +} \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx950-GMM.json b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx950-GMM.json new file mode 100644 index 000000000..49591091d --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx950-GMM.json @@ -0,0 +1,203 @@ +{ + "gmm": { + "default": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "tiny_shapes": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 1 + }, + "k_heavy": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "n_heavy": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "balanced_large_n": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "very_large_m": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_shapes": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 16, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_m_moderate_n": { + 
"BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "default_no_trans_rhs": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "k_heavy_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "very_large_m_small_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "balanced_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "n_very_heavy_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 8, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_k_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + } + }, + "ptgmm": { + "default": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + }, + "high_group_count": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + }, + "small_n": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + }, + "small_n_high_group": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 2 + }, + "accumulate": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + } + }, + "nptgmm": { + "default": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 2 + }, + "small_n": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 16, + "num_warps": 8, + "num_stages": 2 + }, + "accumulate": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 2 + } + } +} \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/gmm/gmm_common.py b/transformer_engine/pytorch/triton_kernels/gmm/gmm_common.py new file mode 100644 index 000000000..469b23b67 --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/gmm_common.py @@ -0,0 +1,741 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# This file is from aiter project (https://github.com/ROCm/aiter) +# commit:04dc719, directory: aiter/ops/triton/utils/gmm_common.py + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import torch +from torch import Tensor + + +# Supported data types. +# ------------------------------------------------------------------------------ + +# Supported data types, as strings. +SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"} + + +# Convert string data type to PyTorch data type. 
+def dtype_from_str(dtype_str: str) -> torch.dtype: + dtype_str = dtype_str.strip().lower() + dtype_str = dtype_str[1:] if dtype_str[0] in {"i", "o"} else dtype_str + assert ( + dtype_str in SUPPORTED_DTYPES_STR + ), "String data type isn't in set of supported string data types." + return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] + + +# Supported data types, as PyTorch types. +SUPPORTED_DTYPES: set[torch.dtype] = { + dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR +} + + +# Convert PyTorch data type to string data type. +def str_from_dtype(dtype: torch.dtype) -> str: + assert ( + dtype in SUPPORTED_DTYPES + ), "PyTorch data type isn't in set of supported PyTorch data types." + return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype] + + +# Default data type, as string. +DTYPE_STR: str = "bf16" +assert ( + DTYPE_STR in SUPPORTED_DTYPES_STR +), "Default string data type isn't in set of supported string data types." + + +# Default data type, as PyTorch type. +DTYPE: torch.dtype = dtype_from_str(DTYPE_STR) + + +# Other defaults. +# ------------------------------------------------------------------------------ + +# Default device. +DEVICE: torch.device | str = "cuda" + +# Default RNG seed for input generation. +RNG_SEED: int = 0 + +# Default number of group sizes. +NUM_GROUP_SIZES: int = 1 + +# Default transposition (NN). +TRANS_LHS: bool = False +TRANS_RHS: bool = False + + +# Parameter checking functions. +# ------------------------------------------------------------------------------ + + +def is_power_of_2(x: int) -> bool: + return (x > 0) and (x & (x - 1) == 0) + + +def check_input_device_dtype( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None +) -> None: + assert ( + lhs.device == rhs.device == group_sizes.device + ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})." + assert ( + lhs.dtype == rhs.dtype + ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})." + assert group_sizes.dtype == torch.int32, "group_sizes type must be int32." + + if bias is not None: + assert ( + bias.device == lhs.device + ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})." + assert ( + bias.dtype == lhs.dtype + ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})." + + +def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None: + assert bias.shape == ( + G, + N, + ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}." + assert bias.stride() == (N, 1), "bias must be row-major (bias.stride() == (N, 1))." + + +# Generation of group sizes. +# ------------------------------------------------------------------------------ + + +# Probabilities for generating random group sizes. +UNUSED_TOKENS_PROB: float = 0.0 +UNUSED_EXPERTS_PROB: float = 0.1 + + +def gen_uniform_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + + base = M // G + remainder = M % G + group_sizes = torch.full((G,), base, dtype=torch.int32, device=device) + if remainder > 0: + group_sizes[:remainder] += 1 + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." 
+ assert ( + torch.sum(group_sizes).item() == M + ), f"Group sizes don't add up to total tokens {M}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_group_sizes( + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, +) -> Tensor: + assert M >= 0, f"Number of tokens M must be non-negative (it's {M})." + assert G > 0, f"Number of experts G must be positive (it's {G})." + assert ( + 0 <= unused_tokens_prob <= 1 + ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})." + assert ( + 0 <= unused_experts_prob <= 1 + ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if unused_tokens_prob > 0: + # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed. + num_unused_tokens = M + while num_unused_tokens == M: + num_unused_tokens = int( + torch.binomial( + torch.tensor(float(M), device=device), + torch.tensor(unused_tokens_prob, device=device), + ).item() + ) + else: + num_unused_tokens = 0 + num_used_tokens = M - num_unused_tokens + assert ( + num_unused_tokens >= 0 + ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})." + assert ( + num_used_tokens > 0 + ), f"Number of used tokens must be positive (it's {num_used_tokens})." + assert ( + num_used_tokens + num_unused_tokens == M + ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})." + + + if unused_experts_prob > 0: + # Some experts may have zero tokens assigned to them. + num_used_experts = 0 + while num_used_experts == 0: + used_experts = torch.nonzero( + torch.rand((G,), device=device) >= unused_experts_prob + ).squeeze() + num_used_experts = used_experts.numel() + else: + used_experts = torch.arange(0, G, device=device) + num_used_experts = G + num_unused_experts = G - num_used_experts + assert ( + num_unused_experts >= 0 + ), f"Number of unused experts must be non-negative (it's {num_unused_experts})." + assert ( + num_used_experts >= 1 + ), f"At least one expert must be used (it's {num_used_experts})." + assert ( + num_unused_experts + num_used_experts == G + ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})." + + group_sizes = torch.bincount( + used_experts[ + torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,)) + ], + minlength=G, + ).to(torch.int32) + + assert ( + len(group_sizes) == G + ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})." + assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative." + assert ( + torch.sum(group_sizes).item() == num_used_tokens + ), f"Group sizes don't add up to used tokens {num_used_tokens}." + assert group_sizes.dtype == torch.int32, "Group sizes must be int32." + + return group_sizes + + +def gen_multiple_group_sizes( + num_group_sizes: int, + M: int, + G: int, + device: torch.device | str = DEVICE, + rng_seed: int | None = RNG_SEED, + unused_tokens_prob: float = UNUSED_TOKENS_PROB, + unused_experts_prob: float = UNUSED_EXPERTS_PROB, + group_sizes_0: Tensor | None = None, +) -> list[Tensor]: + assert ( + num_group_sizes > 0 + ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}." 
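+    # Draw the requested number of random group-size vectors; when a precomputed
+    # group_sizes_0 is supplied (e.g. the uniform split used by the tests), it
+    # takes slot 0 and only num_group_sizes - 1 random vectors are generated.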
+ multiple_group_sizes = [ + gen_group_sizes( + M, + G, + device=device, + rng_seed=rng_seed if g == 0 else None, + unused_tokens_prob=unused_tokens_prob, + unused_experts_prob=unused_experts_prob, + ) + for g in range( + num_group_sizes if group_sizes_0 is None else num_group_sizes - 1 + ) + ] + if group_sizes_0 is not None: + multiple_group_sizes.insert(0, group_sizes_0) + assert ( + len(multiple_group_sizes) == num_group_sizes + ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})." + return multiple_group_sizes + + +# GMM helpers: tensor generation. +# ------------------------------------------------------------------------------ + + +def gen_gmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert M > 0, f"Number of lhs rows M must be positive (M = {M})." + assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + lhs = torch.randn((M, K), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + if trans_rhs: + rhs = torch.randn((G, N, K), dtype=torch.float32, device=device).permute( + 0, 2, 1 + ) + else: + rhs = torch.randn((G, K, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + out = torch.empty((M, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_gmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = False, + trans_rhs: bool = TRANS_RHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor]: + lhs, rhs, group_sizes_0 = gen_gmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_rhs=trans_rhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type) + bias = None + if use_bias: + torch.manual_seed(rng_seed + 1000) # Different seed for bias + bias = torch.randn(G, N, dtype=input_type, device=device) + + return lhs, rhs, multiple_group_sizes, out, bias + + +# GMM helpers: get information from tensors. +# ------------------------------------------------------------------------------ + + +def get_gmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." 
+ assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + M, lhs_k = lhs.shape + rhs_g, rhs_k, N = rhs.shape + group_sizes_g = group_sizes.shape[0] + + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_g == group_sizes_g + ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})." + G = rhs_g + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_gmm_output( + M: int, + N: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert M > 0, f"Number of out rows M must be positive (M = {M})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + M, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})." + return existing_out + + return gen_gmm_output( + M, + N, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})." + assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})." + + lhs_m, lhs_k = lhs.shape + G, rhs_k, rhs_n = rhs.shape + out_m, out_n = out.shape + + assert ( + lhs_m == out_m + ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, rhs = {out_m})." + M = lhs_m + assert ( + lhs_k == rhs_k + ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})." + K = lhs_k + assert ( + rhs_n == out_n + ), f"N dimension of rhs and out don't match (lhs = {rhs_n}, rhs = {out_n})." + N = rhs_n + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + is_lhs_row_major = lhs.stride() == (K, 1) + assert is_lhs_row_major, "lhs must be row-major." + is_rhs_row_major = rhs.stride() == (K * N, N, 1) + is_rhs_col_major = rhs.stride() == (K * N, 1, K) + assert ( + is_rhs_row_major != is_rhs_col_major + ), "rhs must be row-major or column-major." + is_out_row_major = out.stride() == (N, 1) + assert is_out_row_major, "out must be row-major." + + # Get rhs leading dimension according to transposition configuration. + ld_rhs = N if is_rhs_row_major else K + + return is_rhs_col_major, ld_rhs + + +# TGMM helpers: tensor generation. 
+# ------------------------------------------------------------------------------ + + +def gen_tgmm_input( + M: int, + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, +) -> tuple[Tensor, Tensor, Tensor]: + assert K > 0, f"Number of lhs rows K must be positive (M = {K})." + assert M > 0, f"Number of lhs columns / rhs rows M must be positive (K = {M})." + assert N > 0, f"Number of rhs columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if rng_seed is not None: + torch.manual_seed(rng_seed) + + if trans_lhs: + lhs = torch.randn((M, K), dtype=torch.float32, device=device).T + else: + lhs = torch.randn((K, M), dtype=torch.float32, device=device) + lhs = lhs.to(preferred_element_type) + + rhs = torch.randn((M, N), dtype=torch.float32, device=device) + rhs = rhs.to(preferred_element_type) + + group_sizes = ( + gen_uniform_group_sizes(M, G, device=device) + if unif_group_sizes + else gen_group_sizes(M, G, device=device, rng_seed=None) + ) + + return lhs, rhs, group_sizes + + +def gen_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + out = torch.empty((G, K, N), dtype=preferred_element_type, device=device) + + return out + + +def gen_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + with_bias_grad: bool = False, +) -> Tensor: + if with_bias_grad: + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + return torch.empty((G, K), device=device, dtype=torch.float32) + else: + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. + return torch.tensor([], device=device, dtype=torch.float32) + + +def gen_tgmm_tensors( + M: int, + K: int, + N: int, + G: int, + num_group_sizes: int, + device: torch.device | str = DEVICE, + input_type: torch.dtype = DTYPE, + output_type: torch.dtype = DTYPE, + trans_lhs: bool = TRANS_LHS, + trans_rhs: bool = False, + rng_seed: int | None = RNG_SEED, + unif_group_sizes: bool = False, + use_bias: bool = False, +) -> tuple[Tensor, Tensor, list[Tensor], Tensor]: + lhs, rhs, group_sizes_0 = gen_tgmm_input( + M, + K, + N, + G, + device=device, + preferred_element_type=input_type, + trans_lhs=trans_lhs, + rng_seed=rng_seed, + unif_group_sizes=unif_group_sizes, + ) + multiple_group_sizes = gen_multiple_group_sizes( + num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0 + ) + out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type) + if use_bias: + bias_grad = gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True) + else: + bias_grad = None + return lhs, rhs, multiple_group_sizes, out, bias_grad + + +# TGMM helpers: get information from tensors. 
+# ------------------------------------------------------------------------------ + + +def get_tgmm_shape( + lhs: Tensor, rhs: Tensor, group_sizes: Tensor +) -> tuple[int, int, int, int]: + assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})." + assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})." + assert ( + group_sizes.dim() == 1 + ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})." + + K, lhs_m = lhs.shape + rhs_m, N = rhs.shape + G = group_sizes.shape[0] + + assert ( + lhs_m == rhs_m + ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})." + M = lhs_m + + assert M > 0, f"M must be positive, it's {M}." + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}" + assert G > 0, f"G must be positive, it's {G}" + + return M, K, N, G + + +def get_tgmm_output( + K: int, + N: int, + G: int, + device: torch.device | str = DEVICE, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, +) -> Tensor: + assert K > 0, f"Number of out rows K must be positive (K = {K})." + assert N > 0, f"Number of out columns N must be positive (N = {N})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_out is not None: + assert ( + existing_out.device == device + ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})." + assert ( + existing_out.dtype == preferred_element_type + ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})." + assert existing_out.shape == ( + G, + K, + N, + ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})." + return existing_out + + return gen_tgmm_output( + K, + N, + G, + device=device, + preferred_element_type=preferred_element_type, + ) + + +def get_tgmm_bias_grad( + K: int, + G: int, + device: torch.device | str = DEVICE, + existing_bias_grad: Tensor | None = None, +) -> Tensor: + """ + Get or validate bias gradient tensor for TGMM. + + If existing_bias_grad is provided, validates its shape, device, dtype, and stride, + and always zeros it before returning (since the kernel uses atomic_add). + If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False). + Parameters + ---------- + K : int + Number of rows in the bias gradient tensor. + G : int + Number of groups. + device : torch.device or str + Device for the tensor. + existing_bias_grad : torch.Tensor or None + Existing bias gradient tensor to validate and use. + Returns + ------- + torch.Tensor + Valid bias gradient tensor or dummy tensor. + """ + assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})." + assert G > 0, f"Number of groups G must be positive (G = {G})." + + if existing_bias_grad is not None: + # Validate existing bias_grad tensor. + expected_shape = (G, K) + assert ( + tuple(existing_bias_grad.shape) == expected_shape + ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}." + assert ( + existing_bias_grad.device == device + ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})." + assert ( + existing_bias_grad.dtype == torch.float32 + ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}." 
+        assert existing_bias_grad.stride() == (
+            K,
+            1,
+        ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}."
+
+        # Always zero the tensor since bias_grad represents gradients for the current
+        # computation and should start fresh. The kernel uses atomic_add which adds to
+        # existing values, so we must zero before the kernel runs.
+        existing_bias_grad.zero_()
+
+        return existing_bias_grad
+
+    else:
+        return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False)
+
+
+def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
+    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
+    assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
+    assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})."
+
+    lhs_k, lhs_m = lhs.shape
+    rhs_m, rhs_n = rhs.shape
+    G, out_k, out_n = out.shape
+
+    assert (
+        lhs_m == rhs_m
+    ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
+    M = lhs_m
+    assert (
+        lhs_k == out_k
+    ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, out = {out_k})."
+    K = lhs_k
+    assert (
+        rhs_n == out_n
+    ), f"N dimension of rhs and out don't match (rhs = {rhs_n}, out = {out_n})."
+    N = rhs_n
+
+    assert M > 0, f"M must be positive, it's {M}."
+    assert K > 0, f"K must be positive, it's {K}."
+    assert N > 0, f"N must be positive, it's {N}."
+    assert G > 0, f"G must be positive, it's {G}."
+
+    is_lhs_row_major = lhs.stride() == (M, 1)
+    is_lhs_col_major = lhs.stride() == (1, K)
+    assert (
+        is_lhs_row_major != is_lhs_col_major
+    ), "lhs must be row-major or column-major."
+    is_rhs_row_major = rhs.stride() == (N, 1)
+    assert is_rhs_row_major, "rhs must be row-major."
+    is_out_row_major = out.stride() == (K * N, N, 1)
+    assert is_out_row_major, "out must be row-major."
+
+    # Get lhs leading dimension according to transposition configuration.
+    ld_lhs = M if is_lhs_row_major else K
+
+    return is_lhs_col_major, ld_lhs
diff --git a/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py b/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py
new file mode 100644
index 000000000..04653eaca
--- /dev/null
+++ b/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py
@@ -0,0 +1,684 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+# This file is from aiter project (https://github.com/ROCm/aiter) and modified for compatibility with transformer_engine.
+# commit: 56f4d93, directory: aiter/ops/triton/_triton_kernels/gmm.py
+
+# Imports.
+# ------------------------------------------------------------------------------
+
+# Python standard library
+import functools
+import json
+import os.path
+
+# Triton
+from ..common import get_arch
+import triton
+import triton.language as tl
+
+# AITER
+from .pid_preprocessing import pid_grid, remap_xcd
+
+
+# Kernel config.
+# ------------------------------------------------------------------------------
+
+
+@functools.lru_cache()
+def get_config(
+    gmm_type: str,
+    M: int,
+    K: int,
+    N: int,
+    G: int,
+    accumulate: bool = False,
+    trans_rhs: bool = False,
+) -> dict[str, int]:
+    assert gmm_type in {
+        "gmm",
+        "ptgmm",
+        "nptgmm",
+    }, f"'{gmm_type}' is an invalid GMM variant."
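+    # Tuning data is loaded lazily from the per-architecture JSON file
+    # (configs/<arch>-GMM.json) on first use and cached on the function object;
+    # functools.lru_cache then memoizes the per-shape lookup below.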
+ if not hasattr(get_config, "_config_dict"): + dev = get_arch() + config_filename = os.path.join(os.path.dirname(__file__), f"configs/{dev}-GMM.json") + assert os.path.exists(config_filename) and os.path.isfile( + config_filename + ), f"'{config_filename}' isn't an existent file." + with open(config_filename, "r") as config_file: + get_config._config_dict = json.load(config_file) + assert all( + gmm_type in get_config._config_dict + for gmm_type in {"gmm", "ptgmm", "nptgmm"} + ), "Not all GMM variants are present in the configuration file." + + # Heuristic-based config selection for gmm + fwd = gmm_type == "gmm" and trans_rhs + if fwd: + k_n_ratio = K / N if N > 0 else 1.0 + n_k_ratio = N / K if K > 0 else 1.0 + + # Prioritize small shapes first (before ratio checks) + if M < 10000 and (N <= 2048 or K <= 2048): + key = "tiny_shapes" + # Very large M with small N (e.g., 3M+ x 1408) + elif M >= 300000 and N <= 2048: + key = "very_large_m" + # Small shapes (M < 50k, small N) + elif M < 50000 and N <= 2816: + key = "small_shapes" + # Small M with moderate N (e.g., 49k x 2048) + elif M < 100000 and N <= 2048: + key = "small_m_moderate_n" + # K-heavy: K >> N (e.g., 32768x16384x6144) + elif k_n_ratio > 2.0: + key = "k_heavy" + # N-heavy: N >> K (e.g., 32768x4096x14336) + elif n_k_ratio > 2.0: + key = "n_heavy" + # Balanced with large N (e.g., 32768x6144x16384) + elif K < 8192 and N >= 10000: + key = "balanced_large_n" + else: + key = "default" + + bwd = gmm_type == "gmm" and not trans_rhs + if bwd: + k_n_ratio = K / N if N > 0 else 1.0 + + # Prioritize small shapes first (before ratio checks) + if M < 10000 and (N <= 2048 or K <= 2048): + key = "tiny_shapes" + # Very large M with small N (e.g., 393k x 1408) + elif M >= 300000 and N <= 2048: + key = "very_large_m_small_n_bwd" + # K >> N (e.g., 32768x16384x6144, 32768x14336x4096) + elif k_n_ratio > 2.5: + key = "k_heavy_bwd" + # N >> K with high G (e.g., 32768x5120x20480, G=16-64) + elif N / K > 3.0 and G >= 16: + key = "n_very_heavy_bwd" + # Balanced or slightly N-heavy with large N (e.g., 32768x6144x16384) + elif N >= 14000 and k_n_ratio < 2.0: + key = "balanced_large_n_bwd" + # Very small K relative to N (e.g., 32768x4096x14336) + elif K < 5000 and N > 10000: + key = "small_k_large_n_bwd" + else: + key = "default" + + # Heuristic-based config selection for ptgmm + elif gmm_type == "ptgmm": + if accumulate: + key = "accumulate" + else: + # Pattern observed from benchmarks: + # - G >= 32: use smaller BLOCK_SIZE_K (128) with more warps (8) + # - N <= 1408: use smaller BLOCK_SIZE_N (128) + # - Otherwise: use larger blocks (256) with fewer warps (4) + + high_group_count = G >= 32 + small_n = N <= 1408 + + if high_group_count and small_n: + key = "small_n_high_group" + elif high_group_count: + key = "high_group_count" + elif small_n: + key = "small_n" + else: + key = "default" + + # Heuristic-based config selection for nptgmm + elif gmm_type == "nptgmm": + if accumulate: + key = "accumulate" + else: + # Pattern observed from benchmarks: + # - N <= 1408: use BLOCK_SIZE_N=128 with num_warps=8 + # - Otherwise: use default BLOCK_SIZE_K=256, BLOCK_SIZE_N=256, num_warps=4 + # NPTGMM is simpler and less sensitive to G than PTGMM + + small_n = N <= 1408 + + if small_n: + key = "small_n" + else: + key = "default" + + assert ( + key in get_config._config_dict[gmm_type] + ), f"Configuration key '{key}' is absent for {gmm_type}." + return get_config._config_dict[gmm_type][key] + + +# Common code shared by GMM and TGMM kernels. 
+# ------------------------------------------------------------------------------ + + +# XCD remapping followed by 1D PID to 2D grid mapping. +@triton.jit +def _remap_xcd_tile_grid( + tile_in_mm, + num_row_tiles, + num_col_tiles, + GROUP_SIZE: tl.constexpr = 1, + NUM_XCDS: tl.constexpr = 8, +): + return pid_grid( + remap_xcd(tile_in_mm, num_row_tiles * num_col_tiles, NUM_XCDS=NUM_XCDS), + num_row_tiles, + num_col_tiles, + GROUP_SIZE_M=GROUP_SIZE, + ) + + +# GMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics( + { + "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"] + == 0, + } +) +@triton.jit +def gmm_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_RHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + USE_BIAS: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input row of lhs and output row of out. Each group reads some rows of + # lhs and writes some rows to out. + last_m = 0 + + # Loop through all (m, K, N) MM problems: + # (m, K) x (K, N) = (m, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M) + # num_m_tiles can be zero if group is empty + tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0") + + num_tiles = num_m_tiles * num_n_tiles + # num_tiles can be zero if group is empty + tl.device_assert(num_tiles >= 0, "num_tiles < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. 
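+            # The kernel is persistent: each program starts at tile = program_id(0),
+            # advances by GRID_DIM after every tile, and here converts the tile index
+            # local to this group's MM into a (tile_m, tile_n) coordinate via the
+            # XCD-aware remapping.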
+ tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_m, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_m = ( + tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + ) % m + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + + lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :] + + if TRANS_RHS: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] + + offs_rhs_n[None, :] * K + ) + else: + rhs_ptrs = ( + rhs_ptr + + g.to(tl.int64) * K * N + + offs_k[:, None] * N + + offs_rhs_n[None, :] + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if K_DIVISIBLE_BY_BLOCK_SIZE_K: + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + else: + k_mask_limit = K - k * BLOCK_SIZE_K + lhs = tl.load( + lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0 + ) + rhs = tl.load( + rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0 + ) + + acc += tl.dot(lhs, rhs, input_precision="ieee") + + lhs_ptrs += BLOCK_SIZE_K + + if TRANS_RHS: + rhs_ptrs += BLOCK_SIZE_K + else: + rhs_ptrs += BLOCK_SIZE_K * N + + # Add bias if enabled + if USE_BIAS: + offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n + bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + # Convert bias to float32 to match accumulator precision + bias = bias.to(tl.float32) + # Broadcast bias across M dimension and add in float32 + acc += bias[None, :] + + # Convert to output dtype after all computations + acc = acc.to(out_ptr.type.element_ty) + + offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :] + ) + + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N), + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Persistent TGMM kernel. 
+# ------------------------------------------------------------------------------ + + +@triton.jit +def tgmm_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + GRID_DIM: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + num_tiles = num_k_tiles * num_n_tiles + tl.device_assert(num_tiles > 0, "num_tiles <= 0") + + # Current tile. Each program computes multiple tiles of each group. + tile = tl.program_id(0) + tl.device_assert(tile >= 0, "tile < 0 (at initialization)") + + # Tile limit of last MM problem (inclusive). + last_mm_tile = 0 + + # Last input column of lhs and input row of rhs. Each group reads some + # columns of lhs and some rows of rhs. + last_m = 0 + + # Loop through all (K, m, N) MM problems: + # (K, m) x (m, N) = (K, N) + # sum(m) = M + for g in range(G): + # Get m dimension of current MM problem. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty + tl.device_assert(m >= 0, "m < 0") + + # Loop through tiles of current MM problem. + while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + # Figure out tile coordinates in current MM problem. + tile_in_mm = tile - last_mm_tile + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + # Do regular MM: + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K + ) + else: + lhs_ptrs = ( + lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :]) + ) + + rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc += tl.dot(lhs, rhs, input_precision="ieee") + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum( + lhs, axis=1 + ) # Sum across M dimension [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * 
BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc += tl.dot(lhs, rhs, input_precision="ieee") + + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + + g.to(tl.int64) * K * N + + offs_out_k[:, None] * N + + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add( + bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed" + ) + + # Go to the next tile by advancing number of programs. + tile += GRID_DIM + tl.device_assert(tile > 0, "tile <= 0 (at update)") + + # Get ready to go to the next MM problem. + + last_mm_tile += num_tiles + # last_mm_tile can be zero if group 0 is skipped + tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)") + + last_m += m + # last_m can be zero if group 0 is skipped + tl.device_assert(last_m >= 0, "last_m < 0 (at update)") + tl.device_assert(last_m <= M, "last_m > M (at update)") + + +# Regular non-persistent TGMM kernel. +# ------------------------------------------------------------------------------ + + +@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])}) +@triton.jit +def tgmm_non_persistent_kernel( + # Tensor pointers: + lhs_ptr, + rhs_ptr, + group_sizes_ptr, + out_ptr, + bias_grad_ptr, + # Tensor shapes: + M: int, + K: int, + N: int, + G: int, + # Meta-parameters: + TRANS_LHS: tl.constexpr, + BLOCK_SIZE_G: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE: tl.constexpr, + COMPUTE_BIAS_GRAD: tl.constexpr, + ACCUMULATE: tl.constexpr, +): + tl.assume(M > 0) + tl.assume(K > 0) + tl.assume(N > 0) + tl.assume(G > 0) + + # Get group ID from grid. + g = tl.program_id(0) + tl.device_assert(g >= 0, "g < 0") + tl.device_assert(g < G, "g >= G") + + # Get m dimension of current MM group. + m = tl.load(group_sizes_ptr + g) + # m can be zero if group is empty. + tl.device_assert(m >= 0, "m < 0") + + # Skip empty groups. + if m == 0: + return + + # Compute sum(group_sizes) until current group g. + # It's the starting column of lhs and starting row of rhs. + offs_g = tl.arange(0, BLOCK_SIZE_G) + group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0) + start_m = tl.sum(group_sizes) + + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0") + + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0") + + # Get MM tile from grid. 
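+    # Unlike the persistent variant, the launch grid is 2D: axis 0 already selected
+    # the group g above, and axis 1 selects the (K, N) tile of that group's MM, so
+    # no tile-walking loop is needed.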
+ tile_in_mm = tl.program_id(1) + tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") + + tile_k, tile_n = _remap_xcd_tile_grid( + tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE + ) + + tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0") + tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0") + + offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K + offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + + if TRANS_LHS: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K + else: + lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :]) + + rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :] + + loop_m = tl.cdiv(m, BLOCK_SIZE_M) + m_divisible_by_block_m = m % BLOCK_SIZE_M == 0 + if not m_divisible_by_block_m: + loop_m -= 1 + + acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32) + # Initialize bias accumulator + bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32) + + for _ in range(0, loop_m): + lhs = tl.load(lhs_ptrs) + rhs = tl.load(rhs_ptrs) + + acc += tl.dot(lhs, rhs, input_precision="ieee") + + # Accumulate for bias gradient: sum lhs across M dimension + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) # [K, M] -> [K] + + if TRANS_LHS: + lhs_ptrs += BLOCK_SIZE_M * K + else: + lhs_ptrs += BLOCK_SIZE_M + + rhs_ptrs += BLOCK_SIZE_M * N + + if not m_divisible_by_block_m: + offs_lhs_k = ( + tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + ) % K + offs_rhs_n = ( + tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) % N + offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0) + rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0) + acc += tl.dot(lhs, rhs, input_precision="ieee") + # Accumulate last chunk for bias gradient + if COMPUTE_BIAS_GRAD and tile_n == 0: + bias_acc += tl.sum(lhs, axis=1) + + acc = acc.to(out_ptr.type.element_ty) + + offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out_ptrs = ( + out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :] + ) + + mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N) + if ACCUMULATE: + # Load existing values and add to them (like beta=1 in BLAS) + old_vals = tl.load(out_ptrs, mask=mask, other=0.0) + tl.store(out_ptrs, acc + old_vals, mask=mask) + else: + # Overwrite output (like beta=0 in BLAS) + tl.store(out_ptrs, acc, mask=mask) + + # Store bias gradient (only for first N tile, sum across all M) + if COMPUTE_BIAS_GRAD and tile_n == 0: + # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics) + bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k + # Use atomic add since multiple K-tiles may write to same expert's bias + tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed") \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/gmm/gmm_wrapper.py b/transformer_engine/pytorch/triton_kernels/gmm/gmm_wrapper.py new file mode 100644 index 000000000..40fe2880a --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/gmm_wrapper.py @@ -0,0 +1,564 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +# This file is from aiter project (https://github.com/ROCm/aiter) +# commit:04dc719, directory: aiter/ops/triton/gmm.py + +# Imports. +# ------------------------------------------------------------------------------ + +# PyTorch +import math +import torch +from torch import Tensor + +# Triton +import triton + +# AITER: GMM utility functions +from .gmm_common import ( + DTYPE, + is_power_of_2, + check_input_device_dtype, + check_bias_shape_stride, + get_gmm_shape, + get_gmm_output, + get_gmm_transposition, + get_tgmm_shape, + get_tgmm_output, + get_tgmm_bias_grad, + get_tgmm_transposition, +) + +# AITER: GMM Triton kernels +from .gmm_kernels import ( + gmm_kernel, + tgmm_persistent_kernel, + tgmm_non_persistent_kernel, + get_config, +) + + +# GMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _gmm_grid( + N: int, + block_size_m: int, + block_size_n: int, + group_sizes_list: list[int], + grid_dim: int, +) -> tuple[int]: + # Pure CPU arithmetic - ZERO syncs! + num_n_tiles = math.ceil(N / block_size_n) + num_tiles = sum( + math.ceil(gs / block_size_m) for gs in group_sizes_list + ) * num_n_tiles + num_programs = min(grid_dim, num_tiles) + + return (num_programs,) + + +def gmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias: Tensor | None = None, + group_sizes_list: list[int] = None, +) -> Tensor: + """ + Perform Group Matrix Multiplication (GMM): out = lhs @ rhs + bias + + lhs rows are divided into G groups. Each group of lhs rows is matrix multiplied with a plane of + rhs 3D tensor and then stored in a slice of out. In PyTorch parlance, it can be implemented as + follows for a given group g: + out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g] + + The size of each group, and their respective start and end positions are specified by + group_sizes tensor. For instance, suppose that group_sizes = [3, 2, 4, 1]. In this particular + case we have 4 groups. The 1st group starts at 0 and ends at 2, the second group starts at 3 and + ends at 4, the third group starts at 5 and ends at 8, and the fourth and final group consists of + just the 10th (last) row of lhs. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (M, K). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 3D input tensor. Shape: (G, K, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. 
+ If provided then it must have shape (M, N), its data type must match preferred_element_type + and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias : torch.Tensor or None, optional + Optional bias tensor. Shape: (G, N). + If provided, bias data type must match lhs and rhs data type, and bias must be on the same + device as other input tensors. Each group g adds bias[g] to the output. + + Returns + ------- + torch.Tensor + The computed output 2D tensor. Shape: (M, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - GMM is implemented with a persistent Triton kernel. + - lhs must be row-major (lhs.stride() == (K, 1)). + - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() == + (K * N, 1, K)). If rhs is row-major then kernel parameter TRANS_RHS == False, this is useful + for implementing forward pass. If rhs is column-major then kernel parameter TRANS_RHS == True, + this is useful for computing the lhs derivative in the backward pass, while fusing the + transposition. + - out must be row-major (out.stride() == (N, 1)). + - bias must be row-major (bias.stride() == (N, 1)) if provided. + """ + use_bias = bias is not None + check_input_device_dtype(lhs, rhs, group_sizes, bias) + + M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes) + + if use_bias: + check_bias_shape_stride(bias, G, N) + + out = get_gmm_output( + M, + N, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_rhs, _ = get_gmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("gmm", M, K, N, G, accumulate=False, trans_rhs=trans_rhs) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid GMM kernel config." + + group_sizes_list = group_sizes_list if group_sizes_list is not None else group_sizes.tolist() + + grid = _gmm_grid( + N, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + group_sizes_list, + config["GRID_DIM"], + ) + + # fmt: off + gmm_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_RHS=trans_rhs, + USE_BIAS=use_bias, + **config, + ) + # fmt: on + + return out + + +# Persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _ptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, + grid_dim: int, +) -> tuple[int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." 
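+    # Every group writes a full (K, N) output plane, so the total tile count is
+    # G * num_k_tiles * num_n_tiles; at most grid_dim persistent programs are
+    # launched to walk it.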
+ num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles = G * num_k_tiles * num_n_tiles + assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}." + num_programs = min(grid_dim, num_tiles) + assert num_programs > 0, f"num_programs must be positive, it's {num_programs}." + return (num_programs,) + + +def ptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'p' in the operator name means that it is implemented with a persistent kernel. There is + also the non-persistent variation, which is implemented with a regular kernel. Please take a + look at nptgmm operator. Both ptgmm and nptgmm implement the same computation, choosing one or + the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. 
+ If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. + + Implementation Notes + -------------------- + - PTGMM is implemented with a persistent Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + if config is None: + config = get_config("ptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + "GRID_DIM", + } + ), "Invalid PTGMM kernel config." + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + grid = _ptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + config["GRID_DIM"], + ) + + # fmt: off + tgmm_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out + + +# Regular non-persistent TGMM PyTorch wrapper. +# ------------------------------------------------------------------------------ + + +def _nptgmm_grid( + K: int, + N: int, + G: int, + block_size_k: int, + block_size_n: int, +) -> tuple[int, int]: + assert K > 0, f"K must be positive, it's {K}." + assert N > 0, f"N must be positive, it's {N}." + assert G > 0, f"G must be positive, it's {G}." + assert is_power_of_2( + block_size_k + ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})." + assert is_power_of_2( + block_size_n + ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})." + num_k_tiles = triton.cdiv(K, block_size_k) + assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}." + num_n_tiles = triton.cdiv(N, block_size_n) + assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}." + num_tiles_per_mm = num_k_tiles * num_n_tiles + assert ( + num_tiles_per_mm > 0 + ), f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}." 
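+    # Non-persistent launch: a 2D grid where axis 0 enumerates the G groups and
+    # axis 1 enumerates the (K, N) tiles of a single group's MM.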
+ return (G, num_tiles_per_mm) + + +def nptgmm( + lhs: Tensor, + rhs: Tensor, + group_sizes: Tensor, + preferred_element_type: torch.dtype = DTYPE, + existing_out: Tensor | None = None, + config: dict[str, int] | None = None, + bias_grad: Tensor | None = None, + accumulate: bool = False, +) -> Tensor: + """ + Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs + + lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with + the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch + parlance, it can be implemented as follows for a given group g: + out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :] + + The 't' in the operator name derives from MaxText implementation + (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py), + which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor + shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N). + + The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular + kernel. There is also the persistent variation, which is implemented with a persistent kernel. + Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation, + choosing one or the other is a matter of performance for the target workload. + + Parameters + ---------- + lhs : torch.Tensor + Left-hand side 2D input tensor. Shape: (K, M). + lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type. + lhs must be on the same device of rhs and group_sizes. + rhs : torch.Tensor + Right-hand side 2D input tensor. Shape: (M, N). + rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type. + rhs must be on the same device of lhs and group_sizes. + group_sizes : torch.Tensor + 1D input tensor describing group sizes. Shape: (G,). + group_sizes data type must be torch.int32 and all its elements must be non-negative. + group_sizes must be on the same device of lhs and rhs. + preferred_element_type : torch.dtype, optional + Desired data type for output tensor. Default is torch.bfloat16. + Supported output types are torch.float16 and torch.bfloat16. + existing_out : torch.Tensor or None, optional + Preallocated output tensor. Default is None. + If provided, results are written into this tensor. Otherwise, a new output tensor is + allocated. + If provided then it must have shape (G, K, N), its data type must match + preferred_element_type and it must be on the same device of other input tensors. + config : dict[str, int] or None, optional + Optional dictionary with kernel metaparameters. If absent, config will be queried from + internal tuning database. + bias_grad : torch.Tensor or None, optional + Optional bias gradient output tensor. Shape: (G, K). + If provided, the kernel will compute the bias gradient and write it to this tensor. + bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), + accumulate : bool, optional + Whether to accumulate into existing output tensor values. Default is False. + If False, output will be overwritten with fresh computation. + If True, results will be added to existing output tensor values. + + Returns + ------- + torch.Tensor + The computed output 3D tensor. Shape: (G, K, N). + Output tensor data type is given by preferred_element_type. + If existing_out is provided then existing_out is also returned. 
+ + Implementation Notes + -------------------- + - NPTGMM is implemented with a non-persistent regular Triton kernel. + - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs + is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel + parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward + pass, while fusing the transposition. + - rhs must be row-major (rhs.stride() == (N, 1)). + - out must be row-major (out.stride() == (K * N, N, 1)). + """ + check_input_device_dtype(lhs, rhs, group_sizes) + + M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes) + + out = get_tgmm_output( + K, + N, + G, + device=lhs.device, + preferred_element_type=preferred_element_type, + existing_out=existing_out, + ) + + trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out) + + # Bias gradient handling. + # ----------------------- + # Get or validate bias gradient tensor. + compute_bias_grad = bias_grad is not None + bias_grad_ptr = get_tgmm_bias_grad( + K, + G, + device=lhs.device, + existing_bias_grad=bias_grad, + ) + + if config is None: + config = get_config("nptgmm", M, K, N, G, accumulate) + + assert all( + key in config + and isinstance(config[key], int) + and ( + is_power_of_2(config[key]) + if key.startswith("BLOCK_SIZE_") + else config[key] > 0 + ) + for key in { + "BLOCK_SIZE_M", + "BLOCK_SIZE_K", + "BLOCK_SIZE_N", + "GROUP_SIZE", + } + ), "Invalid NPTGMM kernel config." + + grid = _nptgmm_grid( + K, + N, + G, + config["BLOCK_SIZE_K"], + config["BLOCK_SIZE_N"], + ) + + # fmt: off + tgmm_non_persistent_kernel[grid]( + # Tensor pointers: + lhs, rhs, group_sizes, out, bias_grad_ptr, + # Tensor shapes: + M, K, N, G, + # Meta-parameters: + TRANS_LHS=trans_lhs, + COMPUTE_BIAS_GRAD=compute_bias_grad, + ACCUMULATE=accumulate, + **config, + ) + # fmt: on + + return out \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/gmm/pid_preprocessing.py b/transformer_engine/pytorch/triton_kernels/gmm/pid_preprocessing.py new file mode 100644 index 000000000..4996b425e --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/pid_preprocessing.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: MIT + +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +# This file is from aiter project (https://github.com/ROCm/aiter) +# commit:8dfaf0a, directory: aiter/ops/triton/utils/_triton/pid_preprocessing.py + +import triton +import triton.language as tl + + +@triton.jit +def remap_xcd_chunked( + pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2 +): + # Compute current XCD and local PID + xcd = pid % NUM_XCDS + # distribute the modulo pids in round robin + if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE): + return pid + local_pid = pid // NUM_XCDS + # Calculate chunk index and position within chunk + chunk_idx = local_pid // CHUNK_SIZE + pos_in_chunk = local_pid % CHUNK_SIZE + # Calculate new PID + new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk + return new_pid + + +@triton.jit +def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8): + ## pid remapping on xcds + # Number of pids per XCD in the new arrangement + pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS + # When GRID_MN cannot divide NUM_XCDS, some xcds will have + # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. 
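+    # For example, GRID_MN = 19 with NUM_XCDS = 8 gives pids_per_xcd = 3: three XCDs
+    # get 3 pids each and the remaining five get 2 each.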
+ # We calculate the number of xcds that have pids_per_xcd pids as + # tall_xcds + tall_xcds = GRID_MN % NUM_XCDS + tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds + # Compute current XCD and local pid within the XCD + xcd = pid % NUM_XCDS + local_pid = pid // NUM_XCDS + # Calculate new pid based on the new grouping + # Note that we need to consider the following two cases: + # 1. the current pid is on a tall xcd + # 2. the current pid is on a short xcd + if xcd < tall_xcds: + pid = xcd * pids_per_xcd + local_pid + else: + pid = ( + tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid + ) + + return pid + + +@triton.jit +def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1): + """ + Maps 1D pid to 2D grid coords (pid_m, pid_n). + + Args: + - pid: 1D pid + - num_pid_m: grid m size + - num_pid_n: grid n size + - GROUP_SIZE_M: tl.constexpr: default is 1 + """ + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + tl.assume(group_size_m >= 0) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + return pid_m, pid_n \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py new file mode 100644 index 000000000..86384fb69 --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. See LICENSE for more information + +"""Triton kernels for grouped GEMM""" + +import triton +import triton.language as tl +import torch +from typing import Iterable, Optional, Tuple, Union, List +import functools +import json +import os.path +import sys +from pathlib import Path + +from .gmm.gmm_wrapper import gmm, ptgmm, nptgmm +import transformer_engine_torch as tex + +def general_grouped_gemm_triton( + A: List[torch.Tensor], + B: List[torch.Tensor], + out: List[torch.Tensor], + out_dtype: torch.dtype, + workspaces: List[torch.Tensor], + layout: str = "TN", + m_splits: torch.Tensor = None, + gelu: bool = False, + grad=False, + accumulate: bool = False, + bias: Optional[List[torch.Tensor]] = None, + use_bias: bool = False, + use_split_accumulator: bool = False, + D_dtype: Optional[tex.DType] = None, + single_output=False, + **kwargs, +) -> list: + """ + Drop-in replacement for general_grouped_gemm using AITER's Triton kernels. 
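+    Dispatches to the gmm wrapper for the forward and dgrad paths and to the ptgmm
+    wrapper for the wgrad path, based on the layout and grad arguments.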
+ + Supports: + - Forward pass (layout="TN"): C = B @ A^T (where A=weights, B=inputs, C=outputs) + - Backward pass dgrad (layout="NN", grad=True): C = B @ A (where A=weights, B=grad_output, C=dgrad) + - Backward pass wgrad (layout="NT", grad=True): C = B^T @ A (where A=inputs, B=grad_output, C=wgrad) + + Args: + A: Left-hand side matrices (weights for forward/dgrad, inputs for wgrad) + B: Right-hand side matrices (inputs for forward, grad_outputs for backward) + out: Output matrices (pre-allocated) + out_dtype: Output dtype + workspace: Workspace tensor (unused, for compatibility) + single_output: Whether to produce single concatenated output + m_splits: List of token counts per expert + bias: List of bias tensors (optional) + use_bias: Whether to apply bias + use_split_accumulator: Unused, for compatibility + layout: "TN" for forward pass, "NN" for dgrad backward pass, "NT" for wgrad backward pass + grad: True for backward pass + accumulate: Whether to accumulate into C (for wgrad only) + + Returns: + Tuple of (outputs, bias_or_grad_bias, gelu_input) to match C++ backend signature + - bias_or_grad_bias: List of bias/grad_bias tensors (or list of bias if passed in) + """ + assert m_splits is not None, "m_splits required for Triton kernel" + assert len(out) > 0, "Output tensor(s) must be pre-allocated and passed in C list" + + # Determine operation type + is_dgrad = (layout == "NN" and grad) + is_wgrad = (layout == "NT" and grad) + + + if is_wgrad: + # WGRAD: ptgmm expects lhs=(K,M), rhs=(M,N), out=(G,K,N) + # A=inputs (list of (m_i, in_features)), B=grad_outputs (list of (m_i, out_features)) + A_tensor = A[0] if len(A) == 1 else torch.cat(A, dim=0) # (M, in_features) + B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, out_features) + out_tensor_3d = out # (G, out_features, in_features) + + # Allocate bias_grad OUTPUT buffer if needed (kernel writes to this) + bias_grad_tensor = None + if use_bias: + G = m_splits.shape[0] + K = B_tensor.shape[1] # out_features + bias_grad_tensor = torch.zeros(G, K, dtype=torch.float32, device=B_tensor.device) + + # Backward pass: C = B^T @ A (wgrad = grad_output^T @ input) + # ptgmm expects lhs shape (K, M), so we need to transpose + ptgmm( + lhs=B_tensor.t(), # (out_features, M) - transpose to get correct shape + rhs=A_tensor, # (M, in_features) + group_sizes=m_splits, + preferred_element_type=out_dtype, + existing_out=out_tensor_3d, # (G, out_features, in_features) + config=None, + bias_grad=bias_grad_tensor, # OUTPUT: (G, out_features) or None + accumulate=accumulate, + ) + + # Convert bias_grad to list to match C++ backend signature + if use_bias and bias_grad_tensor is not None: + grad_biases = list(torch.unbind(bias_grad_tensor, dim=0)) + else: + grad_biases = [None] * len(out) if bias is None else bias + + # Return appropriate output format + return_out = out_tensor_3d.view(-1, out_tensor_3d.shape[-1]) if single_output else out + return return_out, grad_biases, None + + elif is_dgrad: + # DGRAD: gmm expects lhs=(M,K), rhs=(G,K,N), out=(M,N) + # A=weights (list of (out_features, in_features)), B=grad_outputs (list of (m_i, out_features)) + A_tensor_3d = torch.stack(A, dim=0) # (G, out_features, in_features) + B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, out_features) + out_tensor = out[0] if len(out) == 1 else torch.cat(out, dim=0) # (M, in_features) + + # Stack bias into 3D if provided + bias_tensor = None + if bias is not None and len(bias) > 0 and bias[0].numel() > 0: + bias_tensor = torch.stack(bias, dim=0) # (G, 
in_features) + + # Backward pass: C = B @ A (dgrad = grad_output @ weight) + gmm( + lhs=B_tensor, # (M, out_features) + rhs=A_tensor_3d, # (G, out_features, in_features) + group_sizes=m_splits, + preferred_element_type=out_dtype, + existing_out=out_tensor, # (M, in_features) + config=None, + bias=bias_tensor, + group_sizes_list=kwargs.get("m_splits_list", []), + ) + + grad_biases = [None] * len(m_splits) if bias is None else bias + return_out = out_tensor if single_output else out + return return_out, grad_biases, None + + else: + # FORWARD: gmm expects lhs=(M,K), rhs=(G,K,N), out=(M,N) + # Forward pass: C = B @ A^T (output = input @ weight^T + bias) + # A=weights (list of (out_features, in_features)), B=inputs (list of (m_i, in_features)) + A_tensor_3d = torch.stack(A, dim=0) # (G, out_features, in_features) + A_tensor_3d = A_tensor_3d.transpose(1, 2) # (G, in_features, out_features) for TN layout + B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, in_features) + out_tensor = out[0] if len(out) == 1 else torch.cat(out, dim=0) # (M, out_features) + + # Stack bias into 3D if provided + bias_tensor = None + if bias is not None and len(bias) > 0 and bias[0].numel() > 0: + bias_tensor = torch.stack(bias, dim=0) # (G, out_features) + + gmm( + lhs=B_tensor, # (M, in_features) + rhs=A_tensor_3d, # (G, in_features, out_features) + group_sizes=m_splits, + preferred_element_type=out_dtype, + existing_out=out_tensor, # (M, out_features) + config=None, + bias=bias_tensor, + group_sizes_list=kwargs.get("m_splits_list", []), + ) + + grad_biases = [None] * len(m_splits) if bias is None else bias + return_out = out_tensor if single_output else out + return return_out, grad_biases, None
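+
+
+# Illustrative usage sketch (an assumption-laden example, not part of the kernel
+# implementation or the test suite): drives the forward path of
+# general_grouped_gemm_triton directly. The shapes, dtypes, and the m_splits_list
+# kwarg below are illustrative choices; in Transformer Engine this wrapper is
+# normally reached through the grouped linear layer rather than called directly.
+if __name__ == "__main__":
+    num_gemms, in_features, out_features = 2, 128, 256
+    m_splits = torch.tensor([96, 160], dtype=torch.int32, device="cuda")
+    # A = per-expert weights, B = per-expert inputs; layout "TN" computes
+    # out_i = inp_i @ weight_i^T for each group i.
+    weights = [
+        torch.randn(out_features, in_features, dtype=torch.bfloat16, device="cuda")
+        for _ in range(num_gemms)
+    ]
+    inputs = [
+        torch.randn(int(m), in_features, dtype=torch.bfloat16, device="cuda")
+        for m in m_splits.tolist()
+    ]
+    out = [
+        torch.empty(
+            int(m_splits.sum()), out_features, dtype=torch.bfloat16, device="cuda"
+        )
+    ]
+    # m_splits_list feeds the host-side grid-size computation in the gmm wrapper.
+    result, _, _ = general_grouped_gemm_triton(
+        weights,
+        inputs,
+        out,
+        torch.bfloat16,
+        workspaces=[],
+        layout="TN",
+        m_splits=m_splits,
+        m_splits_list=m_splits.tolist(),
+    )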