diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index a7aabbbcb6..6aad0876f0 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -88,16 +88,16 @@ struct TestParams { std::vector> make_shapes(ShapeCase scase) { switch (scase) { case ShapeCase::kAllSame: - return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; + return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}}; case ShapeCase::kSameFirst: // Same M (first dim), varying N and K - return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; + return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}}; case ShapeCase::kSameLast: // Same N (last dim), varying M and K - return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; + return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}}; case ShapeCase::kAllDifferent: default: - return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; + return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}}; } } @@ -247,6 +247,8 @@ void run_grouped_gemm_case(const TestParams& params) { nullptr, // config (use defaults) 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + // Compare results for (size_t i = 0; i < num_gemms; ++i) { Tensor grouped_split("grouped_D" + std::to_string(i), std::vector{static_cast(std::get<0>(shapes[i])), @@ -289,7 +291,6 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo kTestParams = { // Basic tests - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 9dd965fa94..684ada47b7 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -55,7 +55,7 @@ def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: - """Create quantizer for given quantization scheme""" + """Create quantizers for given quantization scheme""" if quantization == "fp8_delayed_scaling": quantizer = Float8Quantizer( diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 8e3b0517ee..0dc8cd8a3d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -46,7 +46,12 @@ is_nvfp4_available, ) from transformer_engine.pytorch import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions import ( + general_gemm, + general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, +) +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states @@ -1993,6 +1998,82 @@ def test_grouped_linear_accuracy( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) +@pytest.mark.parametrize("single_weight", [True, False], ids=["single_weight", "multi_weight"]) +def test_grouped_linear_m_splits_tensor(single_weight): + """Test GroupedLinear with m_splits as torch tensor (no_quantization/bf16). + grouped_tensor_path is chosen and must match reference (single_weight vs reference model, + or multi_weight list m_splits vs tensor m_splits). + """ + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + num_gemms = 3 + in_features = 32 + out_features = 64 + m_splits = torch.tensor([5, 7, 9], device="cuda", dtype=torch.int64) + m_splits_list = [5, 7, 9] + dtype = torch.bfloat16 + m_total = int(m_splits.sum().item()) + + reference_model = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + single_weight=False, + ) + with torch.no_grad(): + ref_weights = [getattr(reference_model, f"weight{i}") for i in range(num_gemms)] + + test_model = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + single_weight=single_weight, + ) + with torch.no_grad(): + if single_weight: + for i, w in enumerate(test_model.grouped_weight_storage.split_into_quantized_tensors()): + w.copy_(ref_weights[i]) + else: + for i in range(num_gemms): + getattr(test_model, f"weight{i}").copy_(ref_weights[i]) + + inp = torch.randn(m_total, in_features, device="cuda", dtype=dtype, requires_grad=True) + inp_ref = inp.detach().clone().requires_grad_() + + if single_weight: + out = test_model(inp, m_splits) + out_ref = reference_model(inp_ref, m_splits) + else: + out = test_model(inp, m_splits) + out_ref = reference_model(inp_ref, m_splits_list) + + torch.testing.assert_close(out, out_ref, **dtype_tols(dtype)) + + out.sum().backward() + out_ref.sum().backward() + + torch.testing.assert_close(inp.grad, inp_ref.grad, **dtype_tols(dtype)) + if single_weight: + ref_wgrad = torch.cat( + [getattr(reference_model, f"weight{i}").grad.view(-1) for i in range(num_gemms)] + ) + torch.testing.assert_close( + getattr(test_model, "weight0").grad, ref_wgrad, **dtype_tols(dtype) + ) + + @pytest.mark.skipif( torch.cuda.get_device_capability() != (9, 0), reason="Only enable CUTLASS grouped gemm on Hopper", @@ -2792,6 +2873,126 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) +def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: + if grouped_tensor.rowwise_data is None: + raise RuntimeError("GroupedTensor rowwise_data is not initialized.") + offset = 0 + for tensor in tensors: + numel = tensor.numel() + grouped_tensor.rowwise_data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False]) +def test_grouped_gemm_grouped_tensor(layout, accumulate): + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + z, m, k, n = (4, 512, 256, 256) + + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_sizes = [split_points[0]] + m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_sizes.append(m - split_points[-1]) + assert sum(m_sizes) == m and len(m_sizes) == z + + dtype = torch.bfloat16 + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output + grad = False + + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad + grad = True + else: # layout == "NT" + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + + out_ref = [o.clone() for o in out] + general_grouped_gemm( + A, + B, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=False, + ) + + device = A[0].device + + def _make_grouped_tensor_from_splits(m_sizes, last_dim): + first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) + return GroupedTensor.make_grouped_tensor( + num_tensors=len(m_sizes), + first_dims=first_dims, + last_dims=None, + logical_first_dim=sum(m_sizes), + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + def _make_grouped_tensor_uniform(num_tensors, first_dim, last_dim): + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=None, + last_dims=None, + logical_first_dim=num_tensors * first_dim, + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + if layout == "TN": + grouped_A = _make_grouped_tensor_uniform(z, n, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, k) + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n) + elif layout == "NN": + grouped_A = _make_grouped_tensor_uniform(z, n, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n) + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k) + else: # layout == "NT" + grouped_A = _make_grouped_tensor_from_splits(m_sizes, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n) + grouped_out = _make_grouped_tensor_uniform(z, n, k) + _pack_grouped_tensor(grouped_A, A) + _pack_grouped_tensor(grouped_B, B) + _pack_grouped_tensor(grouped_out, out) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=accumulate, + ) + + out_grouped = grouped_out.split_into_quantized_tensors() + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) + + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index dc4757ab90..077b4659fb 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -11,6 +11,7 @@ #include #include +#include #include "../common.h" #include "../util/cuda_runtime.h" @@ -138,7 +139,6 @@ struct GroupedGemmSetupWorkspace { offset += ptr_size; ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - // Int arrays for storage dimensions (4-byte aligned) ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index a37f1c2d4d..22364aaf90 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -5,6 +5,7 @@ """Python interface for GEMM extensions""" from typing import Iterable, Optional, Tuple, Union, List +import ctypes import os import functools import torch @@ -14,6 +15,11 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage + +try: + from ..tensor.storage.grouped_tensor import GroupedTensor +except ModuleNotFoundError: # Backward compatibility with old import paths + from ..tensor.grouped_tensor import GroupedTensor from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -22,6 +28,7 @@ __all__ = [ "general_gemm", "general_grouped_gemm", + "general_grouped_gemm_for_grouped_tensor", ] @@ -284,3 +291,94 @@ def general_grouped_gemm( ) return out, bias, gelu_input + + +def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int: + """Return workspace size for grouped GEMM pointer setup. + Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu. + """ + ptr_bytes = ctypes.sizeof(ctypes.c_void_p) + int_bytes = ctypes.sizeof(ctypes.c_int) + ptr_size = num_tensors * ptr_bytes + int_size = num_tensors * int_bytes + k_ptr_alignment = 16 + aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment + size = 8 * aligned_ptr_size + 6 * int_size + alignment = 256 + return ((size + alignment - 1) // alignment) * alignment + + +def general_grouped_gemm_for_grouped_tensor( + A, + B, + out, + *, + layout: str = "TN", + accumulate: bool = False, + alpha: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Grouped GEMM using GroupedTensor inputs. + + This uses nvte_grouped_gemm and supports different per-matrix shapes. + + The caller must ensure that GroupedTensor metadata is already compatible with the + underlying GEMM implementation (e.g., aligned offsets and output metadata layout). + """ + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." + transa = layout[0] == "T" + transb = layout[1] == "T" + + num_tensors = A.num_tensors + assert A.num_tensors == B.num_tensors == out.num_tensors, ( + f"GroupedTensor num_tensors must match: A={A.num_tensors}, B={B.num_tensors}," + f" out={out.num_tensors}" + ) + + if out.data is not None: + device = out.data.device + elif out.columnwise_data is not None: + device = out.columnwise_data.device + else: + raise ValueError("Output GroupedTensor must have allocated data.") + + if alpha is None: + alpha = torch.ones(num_tensors, dtype=torch.float32, device=device) + if beta is None: + if accumulate: + beta = torch.ones(num_tensors, dtype=torch.float32, device=device) + else: + beta = torch.zeros(num_tensors, dtype=torch.float32, device=device) + + if not alpha.is_cuda or not beta.is_cuda: + raise ValueError("alpha and beta must be CUDA tensors.") + + workspace_setup = torch.empty( + get_grouped_gemm_setup_workspace_size(num_tensors), + dtype=torch.uint8, + device=device, + ) + workspace_cublas = torch.empty( + get_cublas_workspace_size_bytes(), + dtype=torch.uint8, + device=device, + ) + + sm_count = get_sm_count() + sm_count = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))) + + C = out + return tex.te_general_grouped_gemm_for_grouped_tensor( + A, + transa, + B, + transb, + C, + out, + alpha, + beta, + workspace_setup, + workspace_cublas, + sm_count, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e4d4e5094c..f18eceb0a7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -149,6 +149,13 @@ std::optional> te_general_grouped_gemm( std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +py::object te_general_grouped_gemm_for_grouped_tensor(py::handle A, bool transa, py::handle B, + bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index d75b0f14c7..d5a8ff5489 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -570,4 +570,72 @@ std::optional> te_general_grouped_gemm( return bias; } +py::object te_general_grouped_gemm_for_grouped_tensor(py::handle A, bool transa, py::handle B, + bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + int math_sm_count) { + using namespace transformer_engine::pytorch::detail; + + init_extension(); + + // Ensure that cublasLt handle is created on the correct device, + // overriding torch.cuda.set_device calls from user side. + // Assumes all tensors passed are on the same device. + at::cuda::CUDAGuard device_guard(workspace_cublas.device()); + + auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A); + auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); + auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D); + + std::optional grouped_C = std::nullopt; + if (!C.is_none()) { + grouped_C = GroupedTensorFromPyTorchGroupedTensor(C); + } + + const size_t num_tensors = grouped_A.num_tensors(); + NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); + NVTE_CHECK(grouped_B.num_tensors() == num_tensors, + "Grouped GEMM requires A and B to have the same num_tensors."); + NVTE_CHECK(grouped_D.num_tensors() == num_tensors, + "Grouped GEMM requires D to have the same num_tensors as inputs."); + if (grouped_C.has_value()) { + NVTE_CHECK(grouped_C->num_tensors() == num_tensors, + "Grouped GEMM requires C to have the same num_tensors as inputs."); + } + + NVTE_CHECK(alpha.numel() == static_cast(num_tensors), + "Grouped GEMM expects alpha to have num_tensors elements."); + NVTE_CHECK(beta.numel() == static_cast(num_tensors), + "Grouped GEMM expects beta to have num_tensors elements."); + + auto te_alpha = makeTransformerEngineTensor(alpha); + auto te_beta = makeTransformerEngineTensor(beta); + + auto te_workspace_setup = makeTransformerEngineTensor( + workspace_setup.data_ptr(), std::vector{static_cast(workspace_setup.numel())}, + DType::kByte); + auto te_workspace_cublas = makeTransformerEngineTensor( + workspace_cublas.data_ptr(), + std::vector{static_cast(workspace_cublas.numel())}, DType::kByte); + + std::optional config; + if (math_sm_count > 0) { + config.emplace(); + config->set_sm_count(math_sm_count); + } + + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm(grouped_A.data(), transa, grouped_B.data(), transb, + grouped_C.has_value() ? grouped_C->data() : nullptr, grouped_D.data(), + te_alpha.data(), te_beta.data(), te_workspace_setup.data(), + te_workspace_cublas.data(), + config.has_value() ? static_cast(*config) : nullptr, + at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(D); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8302a13010..5585d0d93b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -276,6 +276,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); + m.def("te_general_grouped_gemm_for_grouped_tensor", + &transformer_engine::pytorch::te_general_grouped_gemm_for_grouped_tensor, + "Grouped GEMM for GroupedTensor"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index eda5e8fc54..36c4f39818 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -216,8 +216,8 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); // Rowwise data - if (!tensor.attr("data").is_none()) { - const auto &data = tensor.attr("data").cast(); + if (!tensor.attr("rowwise_data").is_none()) { + const auto &data = tensor.attr("rowwise_data").cast(); DType data_dtype = quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3e7b57cf1..a0e0f28996 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,8 +3,12 @@ # See LICENSE for license information. """GroupedLinear API""" -from typing import Union, Optional, Callable, Tuple, List +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage +from torch._tensor import Tensor +from typing import Any, Union, Optional, Callable, Tuple, List from itertools import chain +from torch.distributed.tensor import DTensor + import warnings import functools @@ -39,6 +43,7 @@ ) from ..cpp_extensions import ( general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -57,6 +62,48 @@ __all__ = ["GroupedLinear"] +def _clone_grouped_tensor_with_data( + grouped_tensor: GroupedTensor, data: torch.Tensor, dtype: torch.dtype +) -> GroupedTensor: + return GroupedTensor( + num_tensors=grouped_tensor.num_tensors, + shape=grouped_tensor.shape, + quantizer=grouped_tensor.quantizer, + dtype=dtype, + data=data, + columnwise_data=grouped_tensor.columnwise_data, + scale_inv=grouped_tensor.scale_inv, + columnwise_scale_inv=grouped_tensor.columnwise_scale_inv, + amax=grouped_tensor.amax, + columnwise_amax=grouped_tensor.columnwise_amax, + scale=grouped_tensor.scale, + first_dims=grouped_tensor.first_dims, + last_dims=grouped_tensor.last_dims, + tensor_offsets=grouped_tensor.tensor_offsets, + offsets=grouped_tensor.offsets, + scale_inv_offsets=grouped_tensor.scale_inv_offsets, + columnwise_scale_inv_offsets=grouped_tensor.columnwise_scale_inv_offsets, + logical_shape=grouped_tensor.logical_shape, + ) + + +def _make_grouped_tensor_for_m_splits(data: torch.Tensor, m_splits: torch.Tensor) -> GroupedTensor: + # Use data.shape[0] to avoid first_dims.sum().item() D2H copy (breaks CUDA graph) + logical_first_dim = data.shape[0] + grouped = GroupedTensor.make_grouped_tensor( + num_tensors=int(m_splits.numel()), + first_dims=m_splits, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=data.shape[-1], + quantizer=None, + device=data.device, + dtype=data.dtype, + ) + grouped.data = data.contiguous().view(-1) + return grouped + + class _GroupedLinear(torch.autograd.Function): """GroupedLinear semi-top level module Calls custom cuda extensions. @@ -76,6 +123,7 @@ def forward( # to reduce CPU overhead due to pytorch arg checking. ( m_splits, + m_splits_is_tensor, use_bias, is_first_microbatch, fp8, @@ -97,10 +145,11 @@ def forward( save_original_input, debug, ) = non_tensor_args - - num_gemms = len(m_splits) - weights = weights_and_biases[:num_gemms] - biases = weights_and_biases[num_gemms:] + num_weight_params = module.num_weight_params + num_gemms = int(m_splits.numel()) if m_splits_is_tensor else len(m_splits) + logical_first_dim = inp.shape[0] if m_splits_is_tensor else sum(m_splits) + weights = weights_and_biases[:num_weight_params] + biases = weights_and_biases[num_weight_params:] device = inp.device weight_requires_grad = weights[0].requires_grad @@ -134,9 +183,11 @@ def forward( if output_quantizers[0] is not None: for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) + no_quantization = not fp8 and weight_quantizers[0] is None # Initialize input tensors - in_features = weights[0].size(-1) + in_features = module.in_features + out_features = module.out_features if inp.size(-1) != in_features: raise ValueError( f"Input tensor (shape={tuple(inp.size())}) is not compatible with " @@ -144,6 +195,14 @@ def forward( ) inp_view = inp.reshape(-1, in_features) inputmats: list + inp_view_cast = None + if m_splits_is_tensor and not no_quantization: + # TODO: Support this path. + raise ValueError( + "GroupedGEMM with grouped tensor path with quantization is not supported yet." + ) + grouped_tensor_path = no_quantization and m_splits_is_tensor + if fp8 and not debug: # Disable bulk allocation when CPU offloading is active: offloading skips small # tensors (like scales), but bulk allocation shares storage across all tensors, @@ -159,7 +218,10 @@ def forward( inp_view, input_quantizers, m_splits, activation_dtype ) else: - inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) + inp_view_cast = cast_if_needed(inp_view, activation_dtype) + inputmats = ( + [inp_view_cast] if grouped_tensor_path else torch.split(inp_view_cast, m_splits) + ) if cpu_offloading: start_offload(*inputmats) @@ -170,7 +232,7 @@ def forward( # FP8 cast to workspace buffer weights_fp8 = [] update_workspace = is_first_microbatch is None or is_first_microbatch - for i in range(num_gemms): + for i in range(num_weight_params): weight_fp8 = module.get_weight_workspace( tensor=weights[i], quantizer=weight_quantizers[i], @@ -191,7 +253,7 @@ def forward( biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases # Initialize output tensor out = torch.empty( - [sum(m_splits), weights_fp8[0].size(0)], + [logical_first_dim, out_features], dtype=activation_dtype, device=device, ) @@ -203,19 +265,35 @@ def forward( if hasattr(recipe, "fp8_gemm_fprop"): use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator - # Perform GEMM - general_grouped_gemm( - weights_fp8, - inputmats, - [out], - output_quantizers, - activation_dtype, - single_output=True, - m_splits=m_splits, - bias=biases, - use_bias=use_bias, - use_split_accumulator=use_split_accumulator, - ) + if grouped_tensor_path: + grouped_weight = _clone_grouped_tensor_with_data( + module.grouped_weight_storage, + cast_if_needed(module.grouped_weight_storage.data, activation_dtype), + activation_dtype, + ) + grouped_input = _make_grouped_tensor_for_m_splits(inputmats[0], m_splits) + grouped_out = _make_grouped_tensor_for_m_splits(out, m_splits) + general_grouped_gemm_for_grouped_tensor( + grouped_weight, + grouped_input, + grouped_out, + layout="TN", + accumulate=False, + ) + else: + # Perform GEMM + general_grouped_gemm( + weights_fp8, + inputmats, + [out], + output_quantizers, + activation_dtype, + single_output=True, + m_splits=m_splits, + bias=biases, + use_bias=use_bias, + use_split_accumulator=use_split_accumulator, + ) if fp8_calibration: for i in range(num_gemms): @@ -230,7 +308,10 @@ def forward( if is_grad_enabled: ctx.weight_quantizers = weight_quantizers - ctx.weights_shape_1 = weights[0].shape[1] + if module.single_weight: + ctx.weights_shape_1 = module.in_features + else: + ctx.weights_shape_1 = weights[0].shape[1] # TODO: update after #1638 is merged. # pylint: disable=fixme if weight_requires_grad: @@ -265,7 +346,6 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects - ctx.grad_input_quantizers = grad_input_quantizers ctx.grad_output_quantizers = grad_output_quantizers ctx.grad_weight_quantizers = grad_weight_quantizers @@ -277,17 +357,22 @@ def forward( # the main_grad buffer lazily before backprop if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward - ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + ctx.main_grad_funcs = [ + weights[i].get_main_grad for i in range(num_weight_params) + ] else: ctx.main_grad_funcs = [ - lambda j=i: weights[j].main_grad for i in range(num_gemms) + lambda j=i: weights[j].main_grad for i in range(num_weight_params) ] else: - ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)] + ctx.main_grad_funcs = [lambda: None for i in range(num_weight_params)] ctx.device = device ctx.output_quantizers = output_quantizers ctx.m_splits = m_splits + ctx.logical_first_dim = logical_first_dim + ctx.grouped_tensor_path = grouped_tensor_path ctx.num_gemms = num_gemms + ctx.num_weight_params = num_weight_params ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -296,6 +381,8 @@ def forward( ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel + ctx.in_features = module.in_features + ctx.out_features = module.out_features ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False @@ -308,7 +395,10 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - + ctx.single_weight = module.single_weight + ctx.grouped_weight_storage = ( + module.grouped_weight_storage if grouped_tensor_path else None + ) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -316,8 +406,9 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with get_nvtx_range_context("_GroupedLinear_backward"): + m_splits = ctx.m_splits saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - N = ctx.num_gemms + N = ctx.num_weight_params inputmats = saved_tensors[:N] weights = saved_tensors[N : 2 * N] origin_weights = saved_tensors[2 * N : 3 * N] @@ -367,7 +458,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) - for i in range(ctx.num_gemms): + for i in range(ctx.num_weight_params): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = DebugQuantizer.multi_tensor_quantize( grad_output_view, @@ -378,10 +469,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: # Only split grad output. Grad bias is fused with # wgrad GEMM. - grad_output = torch.split( - cast_if_needed(grad_output_view, ctx.activation_dtype), - ctx.m_splits, - ) + if ctx.grouped_tensor_path: + out = cast_if_needed(grad_output_view, ctx.activation_dtype) + grad_output = [out] + grouped_grad_output = _make_grouped_tensor_for_m_splits(out, m_splits) + else: + grad_output = torch.split( + cast_if_needed(grad_output_view, ctx.activation_dtype), + ctx.m_splits, + ) if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -399,27 +495,39 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], recipe.fp8_gemm_dgrad.use_split_accumulator ) dgrad = torch.empty( - (sum(ctx.m_splits), ctx.weights_shape_1), + ctx.inp_shape, dtype=ctx.activation_dtype, device=ctx.device, ) + # Make sure weights are available in column-wise format # for dgrad computation. for weight in weights: if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) - general_grouped_gemm( - weights, - grad_output, - [dgrad], - ctx.grad_input_quantizers, - ctx.activation_dtype, - single_output=True, - layout="NN", - m_splits=ctx.m_splits, - grad=True, - use_split_accumulator=dgrad_gemm_use_split_accumulator, - ) + if ctx.grouped_tensor_path: + grouped_weight = ctx.grouped_weight_storage + grouped_dgrad = _make_grouped_tensor_for_m_splits(dgrad, m_splits) + general_grouped_gemm_for_grouped_tensor( + grouped_weight, + grouped_grad_output, + grouped_dgrad, + layout="NN", + accumulate=False, + ) + else: + general_grouped_gemm( + weights, + grad_output, + [dgrad], + ctx.grad_input_quantizers, + ctx.activation_dtype, + single_output=True, + layout="NN", + m_splits=ctx.m_splits, + grad=True, + use_split_accumulator=dgrad_gemm_use_split_accumulator, + ) if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD @@ -429,7 +537,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], wgrad_gemm_use_split_accumulator = ( recipe.fp8_gemm_wgrad.use_split_accumulator ) - if ctx.fuse_wgrad_accumulation: + grouped_wgrad = None + if ctx.grouped_tensor_path and ctx.fuse_wgrad_accumulation: + raise NotImplementedError( + "Fused wgrad accumulation is not supported with grouped tensor path." + ) + if ctx.grouped_tensor_path: + # Wgrad GEMM writes one output per group; use num_gemms (not num_weight_params). + num_wgrad_tensors = ctx.num_gemms + grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_wgrad_tensors, + shape=[(ctx.out_features, ctx.in_features)] * num_wgrad_tensors, + quantizer=None, + dtype=ctx.activation_dtype, + device=ctx.device, + ) + if ctx.single_weight: + wgrad_list = [grouped_wgrad.data.view(-1)] + else: + wgrad_list = grouped_wgrad.split_into_quantized_tensors() + elif ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: wgrad_list = [ @@ -461,32 +588,66 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.activation_dtype, ) else: - inputmats = torch.split( - cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + if ctx.grouped_tensor_path: + inputmats = [cast_if_needed(inp_view, ctx.activation_dtype)] + else: + inputmats = torch.split( + cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + ) + + if ctx.grouped_tensor_path: + + def grouped_gemm_wgrad_grouped_tensor(inputmat, grad_output, grouped_wgrad): + grouped_input = _make_grouped_tensor_for_m_splits(inputmat, ctx.m_splits) + grouped_grad_output = _make_grouped_tensor_for_m_splits( + grad_output, ctx.m_splits ) - grouped_gemm_wgrad = functools.partial( - general_grouped_gemm, - quantization_params=ctx.grad_weight_quantizers, - out_dtype=ctx.activation_dtype, - layout="NT", - grad=True, - m_splits=ctx.m_splits, - use_bias=ctx.use_bias if grad_biases[0] is None else None, - bias=biases, - use_split_accumulator=wgrad_gemm_use_split_accumulator, - accumulate=( - accumulate_wgrad_into_param_main_grad - if not getattr(weights[0], "overwrite_main_grad", False) - else False - ), - ) + # dW = grad_output^T @ input -> (out_features, m) @ (m, in_features). + # Row-wise: A (m, n) -> cuBLAS (n, m); use A=grad_output, B=input. + # Layout NT: op(A)=(n, m), op(B)^T=(m, k) -> D = (n, k). + general_grouped_gemm_for_grouped_tensor( + grouped_grad_output, + grouped_input, + grouped_wgrad, + layout="NT", + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), + ) + return None, [None] * ctx.num_weight_params, None + + grouped_gemm_wgrad = grouped_gemm_wgrad_grouped_tensor + else: + grouped_gemm_wgrad = functools.partial( + general_grouped_gemm, + quantization_params=ctx.grad_weight_quantizers, + out_dtype=ctx.activation_dtype, + layout="NT", + grad=True, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=wgrad_gemm_use_split_accumulator, + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), + ) # WGRAD if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) + elif ctx.grouped_tensor_path: + # Pass 2D view so _make_grouped_tensor_for_m_splits gets correct logical_last_dim + grad_output_2d = grad_output[0].view(ctx.logical_first_dim, ctx.out_features) + # wgrad_list shares the same memory with grouped_wgrad + grouped_gemm_wgrad(inputmats[0], grad_output_2d, grouped_wgrad) else: _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) - for i in range(ctx.num_gemms): + for i in range(ctx.num_weight_params): if grad_biases[i] is None: grad_biases[i] = grad_biases_[i] del grad_biases_ @@ -523,14 +684,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad): for weight, wgrad in zip(origin_weights, wgrad_list) ] else: - wgrad_list = [None] * ctx.num_gemms + wgrad_list = [None] * (ctx.num_weight_params) if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() and not ctx.fp8 ): - grad_biases = [None] * ctx.num_gemms + grad_biases = [None] * (ctx.num_weight_params) if ctx.reduce_and_update_bwd_fp8_tensors: FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -630,6 +791,7 @@ def __init__( save_original_input: bool = False, single_grouped_parameter: bool = False, name: Optional[str] = None, + single_weight: bool = False, ) -> None: super().__init__(name) @@ -638,6 +800,9 @@ def __init__( self.in_features = in_features self.out_features = out_features self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + if single_weight: + bias = False + return_bias = False self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias @@ -694,14 +859,30 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - for i in range(self.num_gemms): - # Construct weight parameter + if self.single_weight and self.primary_weights_in_fp8: + raise ValueError("Single weight is only supported for High precision weights.") + + if self.single_weight: + shape_weight = [(self.out_features * self.num_gemms * self.in_features,)] + shape_bias = [(self.out_features * self.num_gemms,)] + param_names = ["weight0", "bias0"] + self.num_weight_params = 1 + num_tensors = 1 + else: + shape_weight = [(self.out_features, self.in_features) for _ in range(self.num_gemms)] + shape_bias = [self.out_features for _ in range(self.num_gemms)] + num_tensors = self.num_gemms + param_names = [f"weight{i}" for i in range(self.num_gemms)] + [ + f"bias{i}" for i in range(self.num_gemms) + ] + self.num_weight_params = self.num_gemms + + for i in range(num_tensors): self.register_parameter( f"weight{i}", torch.nn.Parameter( torch.empty( - self.out_features, - self.in_features, + shape_weight[i], device=device, dtype=self.params_dtype, ), @@ -717,7 +898,7 @@ def __init__( f"bias{i}", torch.nn.Parameter( torch.empty( - self.out_features, + shape_bias[i], device=device, dtype=self.params_dtype, ), @@ -736,9 +917,8 @@ def __init__( if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): - for i in range(self.num_gemms): - if name in (f"weight{i}", f"bias{i}"): - param.skip_backward_post_hook = True + if name in param_names: + param.skip_backward_post_hook = True def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" @@ -757,6 +937,17 @@ def make_grouped_weights(self, defer_init=False) -> None: if defer_init: return + if self.single_weight: + weight = getattr(self, "weight0") + logical_shape = (self.num_gemms * self.out_features, self.in_features) + self.grouped_weight_storage = GroupedTensor( + num_tensors=self.num_gemms, + shape=[(self.out_features, self.in_features) for _ in range(self.num_gemms)], + quantizer=None, + dtype=self.params_dtype, + data=weight, + logical_shape=logical_shape, + ) weight_quantizers = self._get_weight_quantizers() recipe = ( weight_quantizers[0]._get_compatible_recipe() @@ -777,7 +968,7 @@ def make_grouped_weights(self, defer_init=False) -> None: dtype=self.params_dtype, device=weights[0].device, ) - + self.grouped_weight_storage = grouped_weights # Copy existing params into storage. with torch.no_grad(): for i in range(self.num_gemms): @@ -833,12 +1024,10 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: # Set parallelism attributes for linear biases if self.use_bias: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): if self.parallel_mode == "row": setattr( - getattr(self, f"bias{i}"), - "sequence_parallel", - self.sequence_parallel, + getattr(self, f"bias{i}"), "sequence_parallel", self.sequence_parallel ) elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) @@ -847,7 +1036,7 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: def forward( self, inp: torch.Tensor, - m_splits: List[int], + m_splits: Union[List[int], torch.Tensor], is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -857,8 +1046,8 @@ def forward( ---------- inp : torch.Tensor Input tensor. - m_splits : List[int] - List of integers representing the split of the input tensor. + m_splits : List[int] | torch.Tensor + List of integers or a device tensor representing the split of the input tensor. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -874,18 +1063,18 @@ def forward( produced) """ debug = self.is_debug_iter() - assert not isinstance( inp, QuantizedTensorStorage ), "GroupedLinear doesn't support input tensor in FP8." - assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." - + m_splits_is_tensor = torch.is_tensor(m_splits) + num_splits = m_splits.numel() if m_splits_is_tensor else len(m_splits) + assert num_splits == self.num_gemms, "Number of splits should match number of GEMMs." is_grad_enabled = torch.is_grad_enabled() inp = self.prepare_forward(inp, num_gemms=self.num_gemms) try: weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_weight_params)] quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() @@ -912,6 +1101,7 @@ def forward( non_tensor_args = ( m_splits, + m_splits_is_tensor, self.apply_bias, is_first_microbatch, self.fp8, @@ -955,10 +1145,10 @@ def backward_dw(self): weight_params = self._get_weight_tensors() bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): if bias_params[i].grad is None: bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) del grad_biases_ @@ -1029,9 +1219,9 @@ def _get_weight_quantizers(self) -> List[Quantizer]: self.quantizers["scaling_fwd"][ self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] ] - for i in range(self.num_gemms) + for i in range(self.num_weight_params) ] - for i in range(self.num_gemms): + for i in range(self.num_weight_params): weight_quantizers[i].internal = not self.primary_weights_in_fp8 return weight_quantizers