From 5213efb047e8c1196ea3f9d3b9fc9f23d0068087 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 3 Mar 2026 05:51:57 +0000 Subject: [PATCH] Cleanup and squash Signed-off-by: Kirthi Shankar Sivamani --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_cumsum.cu | 92 ++++++ .../test_nvfp4_group_quantize_graph_safe.py | 2 +- tests/pytorch/test_grouped_tensor.py | 75 ++--- tests/pytorch/test_sanity.py | 130 +------- transformer_engine/common/common.cu | 55 ++++ .../transformer_engine/transformer_engine.h | 13 + transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/misc.cpp | 28 ++ .../pytorch/csrc/extensions/pybind.cpp | 18 +- transformer_engine/pytorch/csrc/pybind.h | 1 + transformer_engine/pytorch/csrc/quantizer.cpp | 108 ++++--- .../pytorch/module/grouped_linear.py | 56 ++-- transformer_engine/pytorch/tensor/__init__.py | 6 + .../pytorch/tensor/grouped_tensor.py | 205 +++++++++++++ .../pytorch/tensor/storage/__init__.py | 2 +- ...ed_tensor.py => grouped_tensor_storage.py} | 280 +++++++++--------- 17 files changed, 698 insertions(+), 376 deletions(-) create mode 100755 tests/cpp/operator/test_cumsum.cu create mode 100644 transformer_engine/pytorch/tensor/grouped_tensor.py rename transformer_engine/pytorch/tensor/storage/{grouped_tensor.py => grouped_tensor_storage.py} (82%) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 56880a428d..94a4ebb1cf 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -24,6 +24,7 @@ add_executable(test_operator test_act.cu test_normalization.cu test_normalization_mxfp8.cu + test_cumsum.cu test_memset.cu test_multi_cast_transpose.cu test_multi_padding.cu diff --git a/tests/cpp/operator/test_cumsum.cu b/tests/cpp/operator/test_cumsum.cu new file mode 100755 index 0000000000..201c1ca171 --- /dev/null +++ b/tests/cpp/operator/test_cumsum.cu @@ -0,0 +1,92 @@ 
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <cstdint>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <transformer_engine/transformer_engine.h>
+#include "../test_common.h"
+
+using namespace transformer_engine;
+
+namespace {
+
+std::vector<int64_t> reference_cumsum_with_leading_zero(const std::vector<int64_t> &input) {
+  std::vector<int64_t> output(input.size() + 1, 0);
+  for (size_t i = 0; i < input.size(); ++i) {
+    output[i + 1] = output[i] + input[i];
+  }
+  return output;
+}
+
+void run_cumsum_test(const std::vector<int64_t> &h_input) {
+  const size_t n = h_input.size();
+  auto h_expected = reference_cumsum_with_leading_zero(h_input);
+  std::vector<int64_t> h_output(n + 1, 0);
+
+  int64_t *d_input = nullptr;
+  int64_t *d_output = nullptr;
+  NVTE_CHECK_CUDA(cudaMalloc(&d_input, n * sizeof(int64_t)));
+  NVTE_CHECK_CUDA(cudaMalloc(&d_output, (n + 1) * sizeof(int64_t)));
+
+  NVTE_CHECK_CUDA(
+      cudaMemcpy(d_input, h_input.data(), n * sizeof(int64_t), cudaMemcpyHostToDevice));
+  nvte_cumsum(d_input, d_output, n, 0 /* stream */);
+  NVTE_CHECK_CUDA(
+      cudaMemcpy(h_output.data(), d_output, (n + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
+
+  NVTE_CHECK_CUDA(cudaFree(d_input));
+  NVTE_CHECK_CUDA(cudaFree(d_output));
+
+  ASSERT_EQ(h_output.size(), h_expected.size());
+  for (size_t i = 0; i < h_output.size(); ++i) {
+    EXPECT_EQ(h_output[i], h_expected[i]) << "Mismatch at output index " << i;
+  }
+}
+
+std::vector<int64_t> make_input(size_t n) {
+  std::vector<int64_t> input(n);
+  for (size_t i = 0; i < n; ++i) {
+    // Deterministic signed values in [-3, 3].
+    input[i] = static_cast<int64_t>(i % 7) - 3;
+  }
+  return input;
+}
+
+std::vector<size_t> cumsum_test_sizes = {
+    1,
+    2,
+    17,
+    256,
+    257,
+    513,
+    1024,
+};
+
+}  // namespace
+
+TEST(CumsumTest, KnownValues) {
+  const std::vector<int64_t> input = {3, -1, 4, 0, -5};
+  run_cumsum_test(input);
+}
+
+class CumsumSizeTestSuite : public ::testing::TestWithParam<size_t> {};
+
+TEST_P(CumsumSizeTestSuite, TestCumsumBySize) {
+  run_cumsum_test(make_input(GetParam()));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    OperatorTest, CumsumSizeTestSuite, ::testing::ValuesIn(cumsum_test_sizes),
+    [](const testing::TestParamInfo<size_t> &info) {
+      return "N" + std::to_string(info.param);
+    });
diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py
index 1e62f91eb8..8d81d578a7 100644
--- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py
+++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py
@@ -10,7 +10,7 @@
 from transformer_engine.pytorch.custom_recipes import utils
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.common.recipe import NVFP4BlockScaling
-from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
+from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
 import pytest
 import torch
diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py
index ad08c0474d..ad20e9422a 100644
--- a/tests/pytorch/test_grouped_tensor.py
+++ b/tests/pytorch/test_grouped_tensor.py
@@ -8,7 +8,7 @@
 import pytest
 import torch
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
+from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
 from transformer_engine.pytorch import (
     Quantizer,
     Float8Quantizer,
@@ -125,8 +125,8 @@ def test_basic_construction_all_same_shape(self) -> None:
 
         grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors, - shape=shape, - quantizer=None, + shapes=shape, + quantizers=None, device="cuda", dtype=torch.float32, ) @@ -147,8 +147,8 @@ def test_basic_construction_varying_first_dim(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=None, + shapes=shape, + quantizers=None, device="cuda", dtype=torch.float32, ) @@ -170,8 +170,8 @@ def test_split_into_quantized_tensors_no_quantization(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=None, + shapes=shape, + quantizers=None, device="cuda", dtype=torch.float32, ) @@ -203,13 +203,14 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None """Test split_into_quantized_tensors for quantized tensors""" num_tensors = 3 shape = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shape) + quantizers = [make_quantizer(quantization, num_tensors, shape) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizer, + shapes=shape, + quantizers=quantizers, device="cuda", + dtype=torch.float32, ) # Get the original data pointer @@ -236,8 +237,8 @@ def test_split_varying_shapes(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=None, + shapes=shape, + quantizers=None, device="cuda", dtype=torch.float32, ) @@ -260,13 +261,14 @@ def test_quantize_inplace(self, quantization: str) -> None: """Test that quantize is done in-place for all recipes""" num_tensors = 3 shape = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shape) + quantizers = [make_quantizer(quantization, num_tensors, shape) for _ in range(num_tensors)] grouped_tensor = 
GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizer, + shapes=shape, + quantizers=quantizers, device="cuda", + dtype=torch.float32, ) # Get original data pointers before quantization @@ -300,13 +302,14 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: """Test quantize with varying shapes""" num_tensors = 3 shape = [(256, 512), (512, 512), (768, 512)] - quantizer = make_quantizer(quantization, num_tensors, shape) + quantizers = [make_quantizer(quantization, num_tensors, shape) for _ in range(num_tensors)] grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=quantizer, + shapes=shape, + quantizers=quantizers, device="cuda", + dtype=torch.float32, ) # Get original data pointers @@ -329,38 +332,6 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: assert rowwise_data.data_ptr() == original_data_ptr + expected_offset cumulative_numel += tensor_shape[0] * tensor_shape[1] - @pytest.mark.parametrize("quantization", _quantization_params) - def test_static_quantize_method(self, quantization: str) -> None: - """Test the static quantize method""" - num_tensors = 3 - shape = [(512, 512) for _ in range(num_tensors)] - quantizer = make_quantizer(quantization, num_tensors, shape) - - # Create input tensors - input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] - - # Use static quantize method - grouped_tensor = GroupedTensor.create_and_quantize( - tensors=input_tensors, - quantizer=quantizer, - device="cuda", - ) - - # Verify the grouped tensor was created correctly - assert grouped_tensor.num_tensors == num_tensors - assert grouped_tensor.has_data() - - # Verify quantized_tensors were created and point to same storage - assert grouped_tensor.quantized_tensors is not None - assert len(grouped_tensor.quantized_tensors) == num_tensors - - original_data_ptr = grouped_tensor.data.data_ptr() - for i, 
qtensor in enumerate(grouped_tensor.quantized_tensors): - rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) - numel = shape[i][0] * shape[i][1] - expected_offset = _rowwise_offset_bytes(i * numel, quantization) - assert rowwise_data.data_ptr() == original_data_ptr + expected_offset - @pytest.mark.parametrize( "shape", [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], @@ -461,8 +432,8 @@ def test_clear(self) -> None: grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, - shape=shape, - quantizer=None, + shapes=shape, + quantizers=None, device="cuda", dtype=torch.float32, ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 3ef8c0983f..384b6774f6 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -138,115 +138,21 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() -def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"): - """ - Verify that tensors are stored in contiguous memory. 
- - Args: - tensors: List or iterable of tensors to check - num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4) - tensor_name: Name to use in error messages - """ - tensor_list = list(tensors) - if len(tensor_list) < 2: - return # Nothing to check - - for i in range(1, len(tensor_list)): - prev_tensor = tensor_list[i - 1] - curr_tensor = tensor_list[i] - - # Calculate expected offset based on previous tensor size - prev_numel = prev_tensor.numel() - expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size() - - # Verify current tensor's data pointer is correctly offset - expected_ptr = prev_tensor.data_ptr() + expected_offset - actual_ptr = curr_tensor.data_ptr() - - assert ( - actual_ptr == expected_ptr - ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}" - - -def check_grouped_tensor_pointers( - weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None +def check_grouped_weight( + module: GroupedLinear, num_gemms: int, out_features: int, in_features: int ): """ - Verify that the pointers of the weights are in contiguous memory for GroupedTensor. - TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach. + Verify GroupedLinear exposes one grouped weight parameter with shape + [num_gemms, out_features, in_features]. """ - - num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1 - - # Check data. - if hasattr(weights[0], "_data") and weights[0]._data is not None: - data_tensors = [w._data for w in weights] - check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data") - - # Check transpose. - if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None: - transpose_tensors = [w._transpose for w in weights] - check_grouped_tensor_pointers_helper( - transpose_tensors, num_elems_in_byte=1, tensor_name="transpose" - ) - - # Check scale_inv. 
- if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None: - scale_inv_tensors = [w._scale_inv for w in weights] - check_grouped_tensor_pointers_helper( - scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv" - ) - - # Check rowwise scale_inv. - if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None: - scale_inv_tensors = [w._rowwise_scale_inv for w in weights] - check_grouped_tensor_pointers_helper( - scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv" - ) - - # Check columnwise scale_inv. - if ( - hasattr(weights[0], "_columnwise_scale_inv") - and weights[0]._columnwise_scale_inv is not None - ): - columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights] - check_grouped_tensor_pointers_helper( - columnwise_scale_inv_tensors, - num_elems_in_byte=1, - tensor_name="columnwise scale_inv", - ) - - # Check rowwise amax. - if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None: - rowwise_amax_tensors = [w._rowwise_amax for w in weights] - check_grouped_tensor_pointers_helper( - rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax" - ) - - # Check columnwise amax. - if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None: - columnwise_amax_tensors = [w._columnwise_amax for w in weights] - check_grouped_tensor_pointers_helper( - columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax" - ) - - # Check rowwise data. - if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None: - rowwise_data_tensors = [w._rowwise_data for w in weights] - check_grouped_tensor_pointers_helper( - rowwise_data_tensors, - num_elems_in_byte=num_elems_in_a_data_byte, - tensor_name="rowwise data", - ) - - # Check columnwise data. 
- if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None: - columnwise_data_tensors = [w._columnwise_data for w in weights] - check_grouped_tensor_pointers_helper( - columnwise_data_tensors, - num_elems_in_byte=num_elems_in_a_data_byte, - tensor_name="columnwise data", - ) + weight_params = [(name, p) for name, p in module.named_parameters() if "weight" in name] + assert len(weight_params) == 1, f"Expected 1 grouped weight parameter, got {len(weight_params)}" + name, weight = weight_params[0] + assert name == "weight", f"Expected grouped parameter name 'weight', got {name}" + assert tuple(weight.shape) == (num_gemms, out_features, in_features), ( + "Grouped weight has unexpected shape. " + f"Expected {(num_gemms, out_features, in_features)}, got {tuple(weight.shape)}" + ) def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): @@ -603,9 +509,6 @@ def test_sanity_grouped_linear( bs = bs * 16 num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) - if single_param: - os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1" - if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -620,13 +523,13 @@ def test_sanity_grouped_linear( ffn_hidden_size, bias=use_bias, params_dtype=dtype, + single_grouped_parameter=single_param, ).cuda() - # Verify that weights are stored in contiguous GroupedTensor storage. - weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)] + # Verify grouped linear exposes a single grouped weight parameter. 
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): if single_param: - check_grouped_tensor_pointers(weights, fp8_recipe) + check_grouped_weight(te_grouped_linear, num_gemms, ffn_hidden_size, config.hidden_size) inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True @@ -645,9 +548,6 @@ def test_sanity_grouped_linear( loss.backward() assert out.shape == (num_tokens, ffn_hidden_size) - if single_param: - del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] - @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 0ec40dc01c..751e08d2c2 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -24,6 +24,50 @@ __global__ void __launch_bounds__(1) reciprocal(scale_inv_ptr, scale); } +constexpr size_t kCumsumThreadsPerBlock = 256; + +__global__ void __launch_bounds__(kCumsumThreadsPerBlock) + cumsum_with_leading_zero_kernel(const int64_t *__restrict__ input, int64_t *__restrict__ output, + size_t num_elements) { + __shared__ int64_t shared_prefix[kCumsumThreadsPerBlock]; + __shared__ int64_t chunk_carry; + + const size_t tid = threadIdx.x; + if (tid == 0) { + chunk_carry = 0; + output[0] = 0; + } + __syncthreads(); + + for (size_t chunk_start = 0; chunk_start < num_elements; chunk_start += blockDim.x) { + const size_t idx = chunk_start + tid; + shared_prefix[tid] = (idx < num_elements) ? 
input[idx] : 0; + __syncthreads(); + + for (size_t offset = 1; offset < blockDim.x; offset <<= 1) { + int64_t addend = 0; + if (tid >= offset) { + addend = shared_prefix[tid - offset]; + } + __syncthreads(); + shared_prefix[tid] += addend; + __syncthreads(); + } + + if (idx < num_elements) { + output[idx + 1] = chunk_carry + shared_prefix[tid]; + } + __syncthreads(); + + if (tid == 0) { + const size_t remaining = num_elements - chunk_start; + const size_t chunk_size = remaining < blockDim.x ? remaining : blockDim.x; + chunk_carry += shared_prefix[chunk_size - 1]; + } + __syncthreads(); + } +} + } // namespace cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { @@ -116,6 +160,17 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream); MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream); } + +void nvte_cumsum(const int64_t *input, int64_t *output, size_t num_elements, cudaStream_t stream) { + NVTE_API_CALL(nvte_cumsum); + NVTE_CHECK(input != nullptr || num_elements == 0, + "Input pointer for cumsum must be allocated when num_elements > 0."); + NVTE_CHECK(output != nullptr, "Output pointer for cumsum must be allocated."); + + cumsum_with_leading_zero_kernel<<<1, kCumsumThreadsPerBlock, 0, stream>>>(input, output, + num_elements); + NVTE_CHECK_CUDA(cudaGetLastError()); +} } // extern "C" void checkCuDriverContext(CUstream stream) { diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e316f8be8c..063cb89f65 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -427,6 +427,19 @@ int nvte_is_non_tn_fp8_gemm_supported(); */ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t 
stream); +/*! \brief Computes exclusive cumulative sum into an output buffer with leading zero. + * + * Given `input` of length `num_elements`, writes `output` of length `num_elements + 1`: + * - output[0] = 0 + * - output[i + 1] = sum(input[0..i]) for i in [0, num_elements) + * + * \param[in] input Pointer to input int64 data on device. + * \param[out] output Pointer to output int64 data on device. + * \param[in] num_elements Number of elements in input. + * \param[in] stream CUDA stream to use for the operation. + */ +void nvte_cumsum(const int64_t *input, int64_t *output, size_t num_elements, cudaStream_t stream); + /*! \brief TE Grouped Tensor type * * NVTEGroupedTensor is a collection of tensors with potentially different shapes diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b2b0751b04..678966817f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -158,6 +158,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); +at::Tensor cumsum(at::Tensor input, std::optional out = std::nullopt); + /*************************************************************************************************** * Activations **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index d667a61d44..cbc5b40571 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -4,7 +4,10 @@ * See LICENSE for license information. 
 ************************************************************************/
 
+#include <ATen/cuda/CUDAContext.h>
+
 #include "../extensions.h"
+#include "pybind.h"
 
 namespace transformer_engine::pytorch {
 
@@ -12,4 +15,29 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); }
 
 size_t get_cudnn_version() { return cudnnGetVersion(); }
 
+at::Tensor cumsum(at::Tensor input, std::optional<at::Tensor> out) {
+  init_extension();
+
+  // Operate on a contiguous int64 CUDA tensor.
+  auto contiguous_input = input.contiguous();
+  NVTE_CHECK(contiguous_input.is_cuda(), "Expected input to be on CUDA.");
+  NVTE_CHECK(contiguous_input.scalar_type() == at::kLong, "Expected input dtype to be int64.");
+  NVTE_CHECK(contiguous_input.dim() == 1, "Expected 1D input tensor.");
+
+  const auto num_elements = static_cast<size_t>(contiguous_input.numel());
+  if (!out) {
+    out = at::empty({static_cast<int64_t>(num_elements + 1)}, contiguous_input.options());
+  }
+  NVTE_CHECK(out->is_cuda(), "Expected output to be on CUDA.");
+  NVTE_CHECK(out->scalar_type() == at::kLong, "Expected output dtype to be int64.");
+  NVTE_CHECK(out->dim() == 1, "Expected 1D output tensor.");
+  NVTE_CHECK(static_cast<size_t>(out->numel()) == num_elements + 1, "Expected output length ",
+             num_elements + 1, " but got ", out->numel(), ".");
+  NVTE_CHECK(out->is_contiguous(), "Expected output to be contiguous.");
+
+  nvte_cumsum(contiguous_input.data_ptr<int64_t>(), out->data_ptr<int64_t>(), num_elements,
+              at::cuda::getCurrentCUDAStream());
+  return std::move(*out);
+}
+
 } // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index b9fc65363d..c05fec1fd8 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -35,8 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
 PyTypeObject *NVFP4TensorPythonClass = nullptr;
 PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
 PyTypeObject *NVFP4QuantizerClass =
nullptr; -std::once_flag extension_init_flag; +PyTypeObject *GroupedTensorPythonClass = nullptr; PyTypeObject *GroupedTensorStoragePythonClass = nullptr; +std::once_flag extension_init_flag; void init_float8_extension() { auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); @@ -103,11 +104,17 @@ void init_nvfp4_extensions() { } void init_grouped_tensor_extension() { - if (GroupedTensorStoragePythonClass) return; + if (GroupedTensorPythonClass && GroupedTensorStoragePythonClass) return; auto grouped_tensor_module = - py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor"); - GroupedTensorStoragePythonClass = reinterpret_cast( + py::module_::import("transformer_engine.pytorch.tensor.grouped_tensor"); + GroupedTensorPythonClass = reinterpret_cast( PyObject_GetAttrString(grouped_tensor_module.ptr(), "GroupedTensor")); + auto grouped_tensor_storage_module = + py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor_storage"); + GroupedTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(grouped_tensor_storage_module.ptr(), "GroupedTensorStorage")); + NVTE_CHECK(GroupedTensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch grouped tensor extension."); NVTE_CHECK(GroupedTensorStoragePythonClass != nullptr, "Internal error: could not initialize pyTorch grouped tensor extension."); } @@ -275,6 +282,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("cumsum", &transformer_engine::pytorch::cumsum, "Exclusive cumsum with leading zero", + py::arg("input"), py::kw_only(), py::arg("out") = py::none(), + py::call_guard()); m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); m.def("compute_amax", 
&transformer_engine::pytorch::compute_amax, diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 059eb5e3fb..9e640537f9 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -43,6 +43,7 @@ extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; +extern PyTypeObject *GroupedTensorPythonClass; extern PyTypeObject *GroupedTensorStoragePythonClass; void init_extension(); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0da5f69197..fb077e7dba 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -67,17 +67,21 @@ std::optional build_grouped_tensor_offsets(const size_t num_tensors, } const auto& first_dims_tensor = first_dims.value(); + NVTE_CHECK(first_dims_tensor.is_cuda(), "first_dims must be on CUDA."); NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, "first_dims must have dtype int64."); NVTE_CHECK(static_cast(first_dims_tensor.numel()) == num_tensors, "first_dims must have length ", num_tensors, "."); const int64_t logical_last_dim_i64 = static_cast(logical_last_dim); - auto scaled_first_dims = first_dims_tensor * logical_last_dim_i64; + auto scaled_first_dims = (first_dims_tensor * logical_last_dim_i64).contiguous(); + auto tensor_offsets = + at::empty({static_cast(num_tensors + 1)}, scaled_first_dims.options()); - // Single kernel needed for these ops. 
- auto cumsum = at::cumsum(scaled_first_dims, 0); - auto zero = at::zeros({1}, cumsum.options()); - return at::cat({zero, cumsum}); + NVTE_SCOPED_GIL_RELEASE({ + nvte_cumsum(scaled_first_dims.data_ptr(), tensor_offsets.data_ptr(), + num_tensors, at::cuda::getCurrentCUDAStream()); + }); + return tensor_offsets; } at::TensorOptions grouped_tensor_data_options(const DType dtype) { @@ -88,6 +92,22 @@ py::object maybe_tensor_to_py(const std::optional& tensor) { return tensor ? py::cast(*tensor) : py::none(); } +py::object make_grouped_quantizers(const py::object& quantizer, const size_t num_tensors) { + if (quantizer.is_none()) { + return py::none(); + } + py::list quantizers; + for (size_t i = 0; i < num_tensors; ++i) { + quantizers.append(quantizer); + } + return std::move(quantizers); +} + +py::handle grouped_tensor_python_class(const bool internal) { + PyTypeObject* cls = internal ? GroupedTensorStoragePythonClass : GroupedTensorPythonClass; + return py::handle(reinterpret_cast(cls)); +} + } // namespace constexpr size_t NVFP4_BLOCK_SIZE = 16; @@ -172,18 +192,19 @@ std::pair NoneQuantizer::create_grouped_tensor getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}, + GetATenDType(dtype), "num_tensors"_a = num_tensors, + "quantizers"_a = make_grouped_quantizers(quantizer, num_tensors), + "data"_a = maybe_tensor_to_py(rowwise_data), "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = py::none(), "columnwise_scale_inv"_a = py::none(), "amax"_a = py::none(), "columnwise_amax"_a = py::none(), "scale"_a = py::none(), "first_dims"_a = 
first_dims.has_value() ? py::cast(*first_dims) : py::none(), "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none()); return {std::move(out_cpp), std::move(out_py)}; } @@ -366,19 +387,20 @@ std::pair Float8Quantizer::create_grouped_tens getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}, + GetATenDType(dtype), "num_tensors"_a = num_tensors, + "quantizers"_a = make_grouped_quantizers(quantizer, num_tensors), + "data"_a = maybe_tensor_to_py(rowwise_data), "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, "columnwise_amax"_a = py::none(), "scale"_a = py::none(), "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none()); return {std::move(out_cpp), std::move(out_py)}; } @@ -673,19 +695,20 @@ std::pair Float8CurrentScalingQuantizer::creat getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}, + GetATenDType(dtype), "num_tensors"_a = num_tensors, + "quantizers"_a = make_grouped_quantizers(quantizer, num_tensors), + "data"_a = maybe_tensor_to_py(rowwise_data), "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, "columnwise_amax"_a = py::none(), "scale"_a = scale, "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none()); return {std::move(out_cpp), std::move(out_py)}; } @@ -1020,19 +1043,20 @@ std::pair Float8BlockQuantizer::create_grouped getTensorShape(*tensor_offsets)); } - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}, + GetATenDType(dtype), "num_tensors"_a = num_tensors, + "quantizers"_a = make_grouped_quantizers(quantizer, num_tensors), + "data"_a = maybe_tensor_to_py(rowwise_data), "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), "columnwise_amax"_a = py::none(), "scale"_a = py::none(), "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none()); return {std::move(out_cpp), std::move(out_py)}; } @@ -1425,19 +1449,20 @@ std::pair MXFP8Quantizer::create_grouped_tenso out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}, + GetATenDType(dtype), "num_tensors"_a = num_tensors, + "quantizers"_a = make_grouped_quantizers(quantizer, num_tensors), + "data"_a = maybe_tensor_to_py(rowwise_data), "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), "columnwise_amax"_a = py::none(), "scale"_a = py::none(), "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none()); return {std::move(out_cpp), std::move(out_py)}; } @@ -1842,10 +1867,13 @@ std::pair NVFP4Quantizer::create_grouped_tenso out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); - py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::object out_py = GroupedTensorClass( - "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), - "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}, + GetATenDType(dtype), "num_tensors"_a = num_tensors, + "quantizers"_a = make_grouped_quantizers(quantizer, num_tensors), + "data"_a = maybe_tensor_to_py(rowwise_data), "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), @@ -1853,9 +1881,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso "columnwise_amax"_a = maybe_tensor_to_py(columnwise_amax), "scale"_a = py::none(), "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), "last_dims"_a = py::none(), - "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), - "logical_shape"_a = std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none()); return {std::move(out_cpp), std::move(out_py)}; } diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b381073d78..4f34fd23da 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -6,7 +6,6 @@ from typing import Union, Optional, Callable, Tuple, List from itertools import chain import warnings -import os import functools import torch @@ -14,7 +13,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, TransformerEngineBaseModule, @@ -595,6 +594,10 @@ class GroupedLinear(TransformerEngineBaseModule): cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. + single_grouped_parameter : bool, default = False + If set to ``True``, grouped weights are stored as a single grouped parameter + instead of one parameter per GEMM. + EXPERIMENTAL and subject to change. Notes ----- @@ -625,6 +628,7 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, save_original_input: bool = False, + single_grouped_parameter: bool = False, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -641,6 +645,7 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input + self.single_grouped_parameter = single_grouped_parameter assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." @@ -767,8 +772,8 @@ def make_grouped_weights(self, defer_init=False) -> None: # Create the weight storage. 
grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=self.num_gemms, - shape=[(self.out_features, self.in_features)] * self.num_gemms, - quantizer=weight_quantizers[0], + shapes=[(self.out_features, self.in_features)] * self.num_gemms, + quantizers=weight_quantizers, dtype=self.params_dtype, device=weights[0].device, ) @@ -781,22 +786,23 @@ def make_grouped_weights(self, defer_init=False) -> None: else: grouped_weights.quantized_tensors[i].copy_(weights[i]) - # Re-register the grouped weights as parameters. + # Re-register as a single grouped weight parameter. + self.register_parameter( + "weight", + torch.nn.Parameter(grouped_weights), + init_fn=self.init_method, + get_rng_state_tracker=self.get_rng_state_tracker, + fp8_meta_index=self._offsets["weight"], + ) for i in range(self.num_gemms): - self.register_parameter( - f"weight{i}", - torch.nn.Parameter(grouped_weights.quantized_tensors[i]), - init_fn=self.init_method, - get_rng_state_tracker=self.get_rng_state_tracker, - fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], - ) + self.register_parameter(f"weight{i}", None) self.set_tensor_parallel_attributes(defer_init=defer_init) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) # Grouped tensor weights is an opt-in feature. 
- if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))): + if self.single_grouped_parameter: self.make_grouped_weights(defer_init=defer_init) def set_tensor_parallel_attributes(self, defer_init=False) -> None: @@ -804,13 +810,22 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: if not defer_init: # Set parallelism attributes for linear weights - for i in range(self.num_gemms): + grouped_weight = getattr(self, "weight", None) + if grouped_weight is not None: set_tensor_model_parallel_attributes( - tensor=getattr(self, f"weight{i}"), + tensor=grouped_weight, is_parallel=True, dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + else: + for i in range(self.num_gemms): + set_tensor_model_parallel_attributes( + tensor=getattr(self, f"weight{i}"), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) # Set parallelism attributes for linear biases if self.use_bias: @@ -933,7 +948,7 @@ def backward_dw(self): with get_nvtx_range_context("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() wgrad_list = tensor_list[2] - weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_params = self._get_weight_tensors() bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): @@ -983,7 +998,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + grouped_weight = getattr(self, "weight", None) + if grouped_weight is not None: + weight_tensors = grouped_weight.quantized_tensors + if weight_tensors is None: + # TODO(ksivaman): Remove this after GEMM integration. 
+ weight_tensors = grouped_weight.split_into_quantized_tensors() + else: + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. " diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index cb199d24b5..5668056700 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -17,10 +17,12 @@ from .storage.mxfp8_tensor_storage import MXFP8TensorStorage from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from .storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .storage.grouped_tensor_storage import GroupedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer +from .grouped_tensor import GroupedTensor from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ @@ -35,11 +37,13 @@ "MXFP8TensorStorage", "Float8BlockwiseQTensorStorage", "NVFP4TensorStorage", + "GroupedTensorStorage", "QuantizedTensor", "Float8Tensor", "MXFP8Tensor", "Float8BlockwiseQTensor", "NVFP4Tensor", + "GroupedTensor", "prepare_for_saving", "restore_from_saved", ] @@ -89,5 +93,7 @@ def get_all_tensor_types(): Float8BlockwiseQTensorStorage, NVFP4Tensor, NVFP4TensorStorage, + GroupedTensor, + GroupedTensorStorage, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py new file mode 100644 index 0000000000..4ed8417956 --- /dev/null +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -0,0 +1,205 @@ +# Copyright (c) 2022-2026, 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch +from torch.utils._pytree import tree_map + +from ..quantized_tensor import QuantizedTensorStorage, Quantizer +from .storage.grouped_tensor_storage import GroupedTensorStorage + + +# For now, conservatively ban all shape manipulating ops. +BANNED_SHAPE_OPS = { + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten.reshape.default, + torch.ops.aten._reshape_alias.default, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.unflatten.int, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.transpose.int, + torch.ops.aten.permute.default, + torch.ops.aten.movedim.int, + torch.ops.aten.t.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.narrow.default, + torch.ops.aten.select.int, + torch.ops.aten.split.Tensor, + torch.ops.aten.chunk.default, + torch.ops.aten.expand.default, + torch.ops.aten.expand_as.default, + torch.ops.aten.cat.default, + torch.ops.aten.stack.default, +} + + +class GroupedTensor(GroupedTensorStorage, torch.Tensor): + """Tensor wrapper class for grouped tensor storage.""" + + def __new__( + cls, + shape: Tuple[int, int], + dtype: torch.dtype, + num_tensors: int, + shapes: Optional[List[Tuple[int, int]]] = None, + quantizers: Optional[List[Optional[Quantizer]]] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + 
tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + ): + del quantizers + del offsets + del scale_inv_offsets + del columnwise_scale_inv_offsets + + if ( + shapes is not None + and len(shapes) == num_tensors + and num_tensors > 0 + and all(shapes[0] == s for s in shapes) + ): + wrapper_shape = (num_tensors, shapes[0][0], shapes[0][1]) + else: + wrapper_shape = shape + + device = None + for maybe_tensor in ( + data, + columnwise_data, + scale_inv, + columnwise_scale_inv, + amax, + columnwise_amax, + scale, + first_dims, + last_dims, + tensor_offsets, + ): + if maybe_tensor is not None: + device = maybe_tensor.device + break + if device is None: + device = torch.device("cuda") + + strides = [1] * len(wrapper_shape) + for i in range(len(wrapper_shape) - 2, -1, -1): + strides[i] = strides[i + 1] * wrapper_shape[i + 1] + return torch.Tensor._make_wrapper_subclass( + cls, + wrapper_shape, + strides=tuple(strides), + storage_offset=0, + dtype=dtype, + layout=torch.strided, + requires_grad=False, + device=device, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + """Dispatch by dequantizing grouped members, then requantizing writes.""" + if kwargs is None: + kwargs = {} + + # Parameter construction calls detach()/alias-like paths. + if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default): + return args[0] + + # Don't allow reshape/view etc. 
+ if func in BANNED_SHAPE_OPS: + raise RuntimeError(f"{cls.__name__} forbids shape-manipulation op: {func} ") + + def grouped_to_stacked_tensor(grouped: GroupedTensor) -> torch.Tensor: + if not grouped.all_same_shape(): + raise NotImplementedError( + "GroupedTensor __torch_dispatch__ currently supports only uniform member shapes" + ) + grouped_members = grouped.quantized_tensors + if grouped_members is None: + grouped_members = grouped.split_into_quantized_tensors() + dequantized_members = [ + ( + member.dequantize(dtype=grouped.get_dtype()) + if isinstance(member, QuantizedTensorStorage) + else member + ) + for member in grouped_members + ] + return torch.stack(dequantized_members, dim=0) + + def maybe_unwrap(arg): + if isinstance(arg, GroupedTensor): + return grouped_to_stacked_tensor(arg) + return arg + + def update_grouped_tensor_inplace(grouped: GroupedTensor, updated: torch.Tensor): + if not grouped.all_same_shape(): + raise NotImplementedError( + "GroupedTensor __torch_dispatch__ currently supports only uniform member shapes" + ) + updated_members = list(updated.unbind(dim=0)) + if grouped.quantizers is None: + grouped_members = grouped.quantized_tensors + if grouped_members is None: + grouped_members = grouped.split_into_quantized_tensors() + for dst, src in zip(grouped_members, updated_members): + dst.copy_(src) + else: + grouped.quantize(updated_members) + + def maybe_update_inplace(arg, new_arg, schema_arg): + if ( + isinstance(arg, GroupedTensor) + and isinstance(new_arg, torch.Tensor) + and hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write + ): + update_grouped_tensor_inplace(arg, new_arg) + elif isinstance(arg, list) and isinstance(new_arg, list): + for a, na in zip(arg, new_arg): + maybe_update_inplace(a, na, schema_arg) + + # In-place op: dequantize members, perform op, write back into grouped storage. 
+ if func._schema.is_mutable: + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op: operate on dequantized stacked tensors. + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + return super().__torch_dispatch__(func, types, new_args, new_kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + # Do not force GroupedTensor on outputs. + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index 7c8a014c1d..44a77d975f 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,4 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 -from .grouped_tensor import GroupedTensor # noqa: F401 +from .grouped_tensor_storage import GroupedTensorStorage # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py similarity index 82% rename from transformer_engine/pytorch/tensor/storage/grouped_tensor.py rename to 
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index bf5792ffc9..08d937d300 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -2,13 +2,12 @@ # # See LICENSE for license information. -"""Grouped tensor class for handling collections of tensors with different shapes""" +"""Grouped tensor storage class for handling collections of tensors with different shapes""" from __future__ import annotations from typing import Optional, Tuple, List, Union import math import torch - from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ..mxfp8_tensor import MXFP8Tensor @@ -21,7 +20,7 @@ from .nvfp4_tensor_storage import NVFP4TensorStorage -class GroupedTensor: +class GroupedTensorStorage: """ EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. @@ -51,10 +50,11 @@ class GroupedTensor: def __init__( self, + shape: Tuple[int, int], + dtype: torch.dtype, num_tensors: int, - shape: Optional[List[Tuple[int, int]]] = None, - quantizer: Optional[Quantizer] = None, - dtype: Optional[torch.dtype] = None, + shapes: Optional[List[Tuple[int, int]]] = None, + quantizers: Optional[List[Optional[Quantizer]]] = None, data: Optional[torch.Tensor] = None, columnwise_data: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, @@ -68,15 +68,16 @@ def __init__( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, - logical_shape: Optional[Tuple[int, int]] = None, ) -> None: """ Initialize a GroupedTensor. 
Args: + shape: 2D tuple representing conceptual shape + dtype: Data type of the grouped tensor num_tensors: Number of tensors in the group - shape: 2D shape of each tensor (len num_tensors) - quantizer: Quantizer for the grouped tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizers: Quantizers for each tensor in the group (len num_tensors) data: Row-wise data buffer (1D flattened) columnwise_data: Column-wise data buffer (1D flattened) scale_inv: Row-wise scale inverse buffer @@ -88,17 +89,14 @@ def __init__( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) offsets: Vector of integer offsets for each tensor. - logical_shape: 2D tuple representing conceptual shape """ self.num_tensors = num_tensors - self.quantizer = quantizer - self.shape = shape - self.dtype = ( - dtype if dtype is not None else torch.float32 - ) # Default to float32 if not provided + self.quantizers = quantizers + self.tensor_shapes = shapes + self.fake_dtype = dtype # Data buffers - self.data = data + self.rowwise_data = data self.columnwise_data = columnwise_data self.scale_inv = scale_inv self.columnwise_scale_inv = columnwise_scale_inv @@ -132,7 +130,7 @@ def __init__( # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) # Represents how the 1D flattened data should be interpreted as 2D # Always 2D with positive dimensions - self.logical_shape = logical_shape if logical_shape is not None else (0, 0) + self.logical_shape = shape # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. # Used as a convenience. 
@@ -145,7 +143,7 @@ def has_data(self) -> bool: Returns: True if data buffer is initialized, False otherwise """ - return self.data is not None + return self.rowwise_data is not None def has_columnwise_data(self) -> bool: """ @@ -239,14 +237,13 @@ def get_dtype(self) -> torch.dtype: The high precision dtype of the data buffer """ - return self.dtype + return self.fake_dtype def clear(self) -> None: """ Reset tensor data and clear all buffers. """ - self.shape = None - self.data = None + self.rowwise_data = None self.columnwise_data = None self.scale_inv = None self.columnwise_scale_inv = None @@ -258,54 +255,39 @@ def clear(self) -> None: self.tensor_offsets = None self.logical_shape = (0, 0) self.num_tensors = 0 - self.quantizer = None + self.quantizers = None self.quantized_tensors = None self.offsets = None self.scale_inv_offsets = None self.columnwise_scale_inv_offsets = None + self.tensor_shapes = [] + self.fake_dtype = torch.float32 def __repr__(self) -> str: """String representation of the GroupedTensor.""" return ( f"GroupedTensor(num_tensors={self.num_tensors}, " - f"shape={self.shape}, " + f"shapes={self.tensor_shapes}, " f"logical_shape={self.logical_shape}, " + f"quantizers={self.quantizers}, " f"dtype={self.get_dtype()})" ) - def __str__(self) -> str: - """User-friendly string representation.""" - shape_info = [] - if self.all_same_shape(): - shape_info.append("uniform shape") - else: - if not self.all_same_first_dim(): - shape_info.append("varying first dim") - if not self.all_same_last_dim(): - shape_info.append("varying last dim") - - return ( - f"GroupedTensor with {self.num_tensors} tensors " - f"({', '.join(shape_info) if shape_info else 'uniform'}), " - f"logical_shape={self.logical_shape}, " - f"dtype={self.get_dtype()}" - ) - @staticmethod def make_grouped_tensor_with_shapes( num_tensors: int, - shape: List[Tuple[int, int]], - quantizer: Optional[Quantizer] = None, + shapes: List[Tuple[int, int]], + quantizers: 
Optional[List[Optional[Quantizer]]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ) -> GroupedTensor: + ) -> GroupedTensorStorage: """ Create a GroupedTensor for storing multiple weight tensors of the same shape. Args: num_tensors: Number of tensors - shape: 2D shape of each tensor (len num_tensors) - quantizer: Quantizer for each tensor + shapes: 2D shape of each tensor (len num_tensors) + quantizers: Quantizers for each tensor (len num_tensors) device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -314,26 +296,26 @@ def make_grouped_tensor_with_shapes( """ # First dim - first_dim_list = [s[0] for s in shape] + first_dim_list = [s[0] for s in shapes] uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) logical_first_dim = sum(first_dim_list) if uniform_first_dim: first_dims = None else: - first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device=device) # Last dim - last_dim_list = [s[1] for s in shape] + last_dim_list = [s[1] for s in shapes] logical_last_dim = last_dim_list[0] assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" - return GroupedTensor.make_grouped_tensor( + return GroupedTensorStorage.make_grouped_tensor( num_tensors=num_tensors, first_dims=first_dims, last_dims=None, logical_first_dim=logical_first_dim, logical_last_dim=logical_last_dim, - quantizer=quantizer, + quantizers=quantizers, device=device, dtype=dtype, ) @@ -345,10 +327,10 @@ def make_grouped_tensor( last_dims: Optional[torch.Tensor], logical_first_dim: int, logical_last_dim: int, - quantizer: Optional[Quantizer] = None, + quantizers: Optional[List[Optional[Quantizer]]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ) -> GroupedTensor: + ) -> GroupedTensorStorage: """ Create a 
GroupedTensor for storing multiple weight tensors of the same shape. @@ -358,8 +340,8 @@ def make_grouped_tensor( last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) logical_first_dim: Logical first dimension logical_last_dim: Logical last dimension - quantizer: Quantizer for each tensor - Used to figure out the recipe and what to allocate. + quantizers: Quantizers for each tensor (len num_tensors) + Used to figure out the recipe and what to allocate. device: Device to allocate tensors on, defaults to current cuda device dtype: Data type of the tensor (for high precision case) @@ -415,10 +397,23 @@ def make_grouped_tensor( # Calculate logical shape based logical_shape = (logical_first_dim, logical_last_dim) - no_quantization = quantizer is None - - rowwise_usage = quantizer.rowwise_usage if not no_quantization else True - columnwise_usage = quantizer.columnwise_usage if not no_quantization else False + no_quantization = quantizers is None or all(q is None for q in quantizers) + reference_quantizer = None + if not no_quantization: + if len(quantizers) != num_tensors: + raise ValueError(f"Expected {num_tensors} quantizers, got {len(quantizers)}") + if any(q is None for q in quantizers): + raise ValueError("quantizers must contain one quantizer per tensor when provided") + reference_quantizer = quantizers[0] + if any( + type(q._get_compatible_recipe()) + is not type(reference_quantizer._get_compatible_recipe()) + for q in quantizers + ): + raise ValueError("All quantizers must have the same recipe for GroupedTensor") + + rowwise_usage = reference_quantizer.rowwise_usage if not no_quantization else True + columnwise_usage = reference_quantizer.columnwise_usage if not no_quantization else False # Calculate total elements across all tensors total_elements = logical_first_dim * logical_last_dim @@ -441,7 +436,7 @@ def make_grouped_tensor( if columnwise_usage: # Allocate columnwise data buffer (1D flattened, uint8) columnwise_data = 
torch.empty(total_elements, dtype=dtype, device=device) - elif quantizer._get_compatible_recipe().mxfp8(): + elif reference_quantizer._get_compatible_recipe().mxfp8(): if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8) data = torch.empty(total_elements, dtype=torch.uint8, device=device) @@ -450,7 +445,7 @@ def make_grouped_tensor( total_scale_elements = 0 scale_inv_offsets = [0] for i, s in enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = reference_quantizer.get_scale_shape(s, False) scale_elements = math.prod(scale_inv_shape) total_scale_elements += scale_elements scale_inv_offsets.append(total_scale_elements) @@ -463,14 +458,14 @@ def make_grouped_tensor( total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] for i, s in enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = reference_quantizer.get_scale_shape(s, False) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.uint8, device=device ) - elif quantizer._get_compatible_recipe().delayed(): + elif reference_quantizer._get_compatible_recipe().delayed(): if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8) data = torch.empty(total_elements, dtype=torch.uint8, device=device) @@ -489,7 +484,7 @@ def make_grouped_tensor( # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) - elif quantizer._get_compatible_recipe().nvfp4(): + elif reference_quantizer._get_compatible_recipe().nvfp4(): if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) @@ -499,7 +494,7 @@ def make_grouped_tensor( total_scale_elements = 0 scale_inv_offsets = [0] for i, s in 
enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = reference_quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) @@ -515,7 +510,7 @@ def make_grouped_tensor( total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] for i, s in enumerate(shape): - columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + columnwise_scale_inv_shape = reference_quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) columnwise_scale_inv = torch.empty( @@ -523,7 +518,7 @@ def make_grouped_tensor( ) # Columnwise amax buffer - one per tensor columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) - elif quantizer._get_compatible_recipe().float8_block_scaling(): + elif reference_quantizer._get_compatible_recipe().float8_block_scaling(): if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8) data = torch.empty(total_elements, dtype=torch.uint8, device=device) @@ -532,7 +527,7 @@ def make_grouped_tensor( total_scale_elements = 0 scale_inv_offsets = [0] for i, s in enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = reference_quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) @@ -544,13 +539,13 @@ def make_grouped_tensor( total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] for i, s in enumerate(shape): - columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + columnwise_scale_inv_shape = reference_quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += 
math.prod(columnwise_scale_inv_shape) columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.float32, device=device ) - elif quantizer._get_compatible_recipe().float8_current_scaling(): + elif reference_quantizer._get_compatible_recipe().float8_current_scaling(): # Current scaling - per-tensor scaling computed on the fly if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8) @@ -572,13 +567,28 @@ def make_grouped_tensor( scale = torch.empty(num_tensors, dtype=torch.float32, device=device) amax = torch.empty(num_tensors, dtype=torch.float32, device=device) else: - raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}") + raise ValueError(f"Unsupported quantizer for GroupedTensor: {reference_quantizer}") + + # Construct wrapper vs storage based on quantizer.internal. + # If quantizers is None (high precision path), default to wrapper class. + # TODO(ksivaman): Properly handle high precision path. 
+ if quantizers is None or all(q is None for q in quantizers): + internal = False + else: + internal = quantizers[0].internal + if internal: + grouped_tensor_class = GroupedTensorStorage + else: + from ..grouped_tensor import GroupedTensor + + grouped_tensor_class = GroupedTensor - grouped_tensor = GroupedTensor( + grouped_tensor = grouped_tensor_class( + logical_shape, + dtype, num_tensors=num_tensors, - shape=shape, - dtype=dtype, - quantizer=quantizer, + shapes=shape, + quantizers=quantizers, data=data, columnwise_data=columnwise_data, scale_inv=scale_inv, @@ -592,7 +602,6 @@ def make_grouped_tensor( offsets=offsets, scale_inv_offsets=scale_inv_offsets, columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, - logical_shape=logical_shape, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() @@ -603,11 +612,11 @@ def split_into_quantized_tensors( ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]: """ Split the GroupedTensor into a list of `num_tensors` - quantized tensors based on the quantizer. No additional memory allocation is performed, + quantized tensors based on the per-tensor quantizers. No additional memory allocation is performed, so the tensors returned are the same as the ones used to create the GroupedTensor. - If quantizer is None, returns normal torch tensors. - If quantizer.internal is True, returns QuantizedTensorStorage. + If quantizers is None, returns normal torch tensors. + If quantizers[i].internal is True, returns QuantizedTensorStorage. Otherwise, returns QuantizedTensor. This API is NOT graph safe, but can be used for testing & debugging. 
@@ -618,10 +627,10 @@ def split_into_quantized_tensors( result = [] - no_quantization = self.quantizer is None + no_quantization = self.quantizers is None or all(q is None for q in self.quantizers) - # if self.shape is None, then trigger D2H copy and get the shape (not graph safe) - if self.shape is None: + # if self.tensor_shapes is None, then trigger D2H copy and get the shape (not graph safe) + if self.tensor_shapes is None: first_dims_list = ( [self.logical_shape[0]] * self.num_tensors if self.first_dims is None @@ -635,7 +644,7 @@ def split_into_quantized_tensors( shape_list = [] for i in range(self.num_tensors): shape_list.append((first_dims_list[i], last_dims_list[i])) - self.shape = shape_list + self.tensor_shapes = shape_list # edge case: handle the case where tensor_offsets is given but offsets is not set if self.offsets is None and self.tensor_offsets is not None: @@ -645,7 +654,7 @@ def split_into_quantized_tensors( if no_quantization: for i in range(self.num_tensors): # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.tensor_shapes[i] # Get tensor data slice if self.offsets is not None: @@ -654,7 +663,7 @@ def split_into_quantized_tensors( end_offset = start_offset + numel if self.has_data(): - tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + tensor_data = self.rowwise_data[start_offset:end_offset].view(tensor_shape) result.append(tensor_data) elif self.has_columnwise_data(): tensor_data = self.columnwise_data[start_offset:end_offset].view( @@ -670,7 +679,7 @@ def split_into_quantized_tensors( end_offset = start_offset + numel if self.has_data(): - tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + tensor_data = self.rowwise_data[start_offset:end_offset].view(tensor_shape) result.append(tensor_data) elif self.has_columnwise_data(): tensor_data = self.columnwise_data[start_offset:end_offset].view( @@ -683,7 +692,11 @@ def split_into_quantized_tensors( return result # Case 2: Quantized tensors 
- recipe = self.quantizer._get_compatible_recipe() + if len(self.quantizers) != self.num_tensors: + raise RuntimeError( + f"Expected {self.num_tensors} quantizers, got {len(self.quantizers)}" + ) + recipe = self.quantizers[0]._get_compatible_recipe() # populate scale_inv_offsets from the tensor offsets if self.scale_inv is not None and self.scale_inv_offsets is None: @@ -698,8 +711,9 @@ def split_into_quantized_tensors( self.columnwise_scale_inv_offsets = self.tensor_offsets // 32 for i in range(self.num_tensors): + quantizer = self.quantizers[i] # Get tensor shape - tensor_shape = self.shape[i] + tensor_shape = self.tensor_shapes[i] numel = tensor_shape[0] * tensor_shape[1] # Get data offsets @@ -712,7 +726,7 @@ def split_into_quantized_tensors( data_end = data_start + numel # Special shape handling for NVFP4. - nvfp4 = self.quantizer._get_compatible_recipe().nvfp4() + nvfp4 = quantizer._get_compatible_recipe().nvfp4() if nvfp4: data_start = data_start // 2 data_end = data_end // 2 @@ -723,15 +737,15 @@ def split_into_quantized_tensors( if self.has_data(): if nvfp4: - rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape) + rowwise_tensor_shape = quantizer.convert_shape_for_fp4(tensor_shape) else: rowwise_tensor_shape = tensor_shape - rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) + rowwise_data = self.rowwise_data[data_start:data_end].view(rowwise_tensor_shape) if self.has_columnwise_data(): - columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape) + columnwise_tensor_shape = quantizer.get_columnwise_shape(tensor_shape) if nvfp4: - columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4( + columnwise_tensor_shape = quantizer.convert_shape_for_fp4( columnwise_tensor_shape ) columnwise_data = self.columnwise_data[data_start:data_end].view( @@ -750,7 +764,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv_offsets[i + 1] # Calculate expected scale shape for MXFP8 - scale_shape = 
self.quantizer.get_scale_shape(tensor_shape, False) + scale_shape = quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -761,25 +775,25 @@ def split_into_quantized_tensors( # for paged stashing, columnwise_scale_inv should depend on the split offsets cscale_end = self.columnwise_scale_inv_offsets[i + 1] - cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) - if self.quantizer.internal: + if quantizer.internal: mxfp8_tensor_class = MXFP8TensorStorage else: mxfp8_tensor_class = MXFP8Tensor tensor = mxfp8_tensor_class( shape=tensor_shape, - dtype=self.dtype, + dtype=self.fake_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=self.quantizer.dtype, - quantizer=self.quantizer, - with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, + with_gemm_swizzled_scales=quantizer.optimize_for_gemm, ) result.append(tensor) @@ -790,18 +804,18 @@ def split_into_quantized_tensors( if self.scale_inv is not None: scale_inv = self.scale_inv[i : i + 1] - if self.quantizer.internal: + if quantizer.internal: float8_tensor_class = Float8TensorStorage else: float8_tensor_class = Float8Tensor tensor = float8_tensor_class( shape=tensor_shape, - dtype=self.dtype, + dtype=self.fake_dtype, data=rowwise_data, fp8_scale_inv=scale_inv, - fp8_dtype=self.quantizer.dtype, - quantizer=self.quantizer, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, data_transpose=columnwise_data, ) result.append(tensor) @@ -818,7 +832,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv_offsets[i + 1] # Get scale shape from quantizer - scale_shape = self.quantizer.get_scale_shape(tensor_shape, 
False) + scale_shape = quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -830,28 +844,28 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv_offsets[i + 1] # Get columnwise scale shape from quantizer - cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( cscale_shape ) # Compute is_2D_scaled and data_format from quantizer attributes - is_2D_scaled = self.quantizer.block_scaling_dim == 2 + is_2D_scaled = quantizer.block_scaling_dim == 2 - if self.quantizer.internal: + if quantizer.internal: float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage else: float8_blockwise_q_tensor_class = Float8BlockwiseQTensor tensor = float8_blockwise_q_tensor_class( shape=tensor_shape, - dtype=self.dtype, + dtype=self.fake_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, - fp8_dtype=self.quantizer.dtype, - quantizer=self.quantizer, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, is_2D_scaled=is_2D_scaled, ) result.append(tensor) @@ -870,7 +884,7 @@ def split_into_quantized_tensors( scale_end = self.scale_inv_offsets[i + 1] # Get scale shape from quantizer - scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + scale_shape = quantizer.get_scale_shape(tensor_shape, False) rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) if ( @@ -882,7 +896,7 @@ def split_into_quantized_tensors( cscale_end = self.columnwise_scale_inv_offsets[i + 1] # Get columnwise scale shape from quantizer - cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cscale_shape = quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( 
cscale_shape ) @@ -894,23 +908,23 @@ def split_into_quantized_tensors( if self.columnwise_amax is not None: amax_columnwise = self.columnwise_amax[i : i + 1] - if self.quantizer.internal: + if quantizer.internal: nvfp4_tensor_class = NVFP4TensorStorage else: nvfp4_tensor_class = NVFP4Tensor tensor = nvfp4_tensor_class( shape=tensor_shape, - dtype=self.dtype, + dtype=self.fake_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, amax_rowwise=amax_rowwise, amax_columnwise=amax_columnwise, - fp4_dtype=self.quantizer.dtype, - quantizer=self.quantizer, - with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, + fp4_dtype=quantizer.dtype, + quantizer=quantizer, + with_gemm_swizzled_scales=quantizer.optimize_for_gemm, ) result.append(tensor) @@ -919,32 +933,6 @@ def split_into_quantized_tensors( return result - @staticmethod - def create_and_quantize( - tensors: int, - quantizer: None | Quantizer, - *, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - noop_flag: Optional[torch.Tensor] = None, - ) -> Tuple[QuantizedTensorStorage, ...]: - """ - Quantize given tensors into quantized tensors with underlying - storage allocated in a GroupedTensor. - """ - - grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=len(tensors), - shape=[t.shape for t in tensors], - quantizer=quantizer, - device=device, - dtype=dtype, - ) - - grouped_tensor.quantize(tensors, noop_flag=noop_flag) - - return grouped_tensor - def quantize( self, tensors: List[torch.Tensor], @@ -956,5 +944,7 @@ def quantize( quantized_tensors = self.split_into_quantized_tensors() for i in range(self.num_tensors): - self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + self.quantizers[i].update_quantized( + tensors[i], quantized_tensors[i], noop_flag=noop_flag + ) return quantized_tensors