|
| 1 | +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# See LICENSE for license information. |
| 4 | + |
| 5 | +"""Tests for GroupedTensor class""" |
| 6 | + |
| 7 | +from typing import List, Tuple |
| 8 | +import pytest |
| 9 | +import torch |
| 10 | +import transformer_engine.pytorch as te |
| 11 | +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor |
| 12 | +from transformer_engine.pytorch import ( |
| 13 | + Quantizer, |
| 14 | + Float8Quantizer, |
| 15 | + Float8CurrentScalingQuantizer, |
| 16 | + Float8BlockQuantizer, |
| 17 | + MXFP8Quantizer, |
| 18 | + NVFP4Quantizer, |
| 19 | +) |
| 20 | +from transformer_engine.pytorch.constants import TE_DType_To_Torch |
| 21 | +import transformer_engine_torch as tex |
| 22 | + |
| 23 | +# Check available recipes |
| 24 | +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) |
| 25 | +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( |
| 26 | + return_reason=True |
| 27 | +) |
| 28 | +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) |
| 29 | +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) |
| 30 | + |
| 31 | +_quantization_params = [ |
| 32 | + pytest.param( |
| 33 | + "fp8_delayed_scaling", |
| 34 | + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), |
| 35 | + ), |
| 36 | + pytest.param( |
| 37 | + "fp8_current_scaling", |
| 38 | + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), |
| 39 | + ), |
| 40 | + pytest.param( |
| 41 | + "fp8_blockwise", |
| 42 | + marks=pytest.mark.skipif( |
| 43 | + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling |
| 44 | + ), |
| 45 | + ), |
| 46 | + pytest.param( |
| 47 | + "mxfp8", |
| 48 | + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), |
| 49 | + ), |
| 50 | + pytest.param( |
| 51 | + "nvfp4", |
| 52 | + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), |
| 53 | + ), |
| 54 | +] |
| 55 | + |
| 56 | + |
| 57 | +def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: |
| 58 | + """Create quantizers for given quantization scheme""" |
| 59 | + |
| 60 | + if quantization == "fp8_delayed_scaling": |
| 61 | + quantizer = Float8Quantizer( |
| 62 | + scale=torch.ones(1, dtype=torch.float32, device="cuda"), |
| 63 | + amax=torch.zeros(1, dtype=torch.float32, device="cuda"), |
| 64 | + fp8_dtype=tex.DType.kFloat8E4M3, |
| 65 | + ) |
| 66 | + elif quantization == "fp8_current_scaling": |
| 67 | + quantizer = Float8CurrentScalingQuantizer( |
| 68 | + fp8_dtype=tex.DType.kFloat8E4M3, |
| 69 | + device="cuda", |
| 70 | + ) |
| 71 | + quantizer.set_usage(rowwise=True, columnwise=False) |
| 72 | + elif quantization == "fp8_blockwise": |
| 73 | + quantizer = Float8BlockQuantizer( |
| 74 | + fp8_dtype=tex.DType.kFloat8E4M3, |
| 75 | + rowwise=True, |
| 76 | + columnwise=False, |
| 77 | + force_pow_2_scales=True, |
| 78 | + amax_epsilon=0.0, |
| 79 | + block_scaling_dim=1, |
| 80 | + ) |
| 81 | + elif quantization == "mxfp8": |
| 82 | + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) |
| 83 | + elif quantization == "nvfp4": |
| 84 | + quantizer = NVFP4Quantizer( |
| 85 | + with_rht=False, |
| 86 | + with_post_rht_amax=False, |
| 87 | + with_2d_quantization=False, |
| 88 | + stochastic_rounding=False, |
| 89 | + with_random_sign_mask=False, |
| 90 | + ) |
| 91 | + else: |
| 92 | + raise ValueError(f"Unknown quantization scheme: {quantization}") |
| 93 | + |
| 94 | + quantizer.internal = False |
| 95 | + |
| 96 | + return quantizer |
| 97 | + |
| 98 | + |
| 99 | +def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor: |
| 100 | + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"): |
| 101 | + return qtensor._data |
| 102 | + if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"): |
| 103 | + return qtensor._rowwise_data |
| 104 | + raise ValueError(f"Unknown quantization scheme: {quantization}") |
| 105 | + |
| 106 | + |
| 107 | +def _rowwise_offset_bytes(numel: int, quantization: str) -> int: |
| 108 | + if quantization == "nvfp4": |
| 109 | + return numel // 2 |
| 110 | + return numel |
| 111 | + |
| 112 | + |
| 113 | +class TestGroupedTensor: |
| 114 | + @staticmethod |
| 115 | + def setup_class(cls) -> None: |
| 116 | + # Configure RNG |
| 117 | + seed = 1234 |
| 118 | + torch.manual_seed(seed) |
| 119 | + torch.cuda.manual_seed(seed) |
| 120 | + |
| 121 | + def test_basic_construction_all_same_shape(self) -> None: |
| 122 | + """Test GroupedTensor construction with all tensors having same shape""" |
| 123 | + num_tensors = 4 |
| 124 | + shape = [(256, 512) for _ in range(num_tensors)] |
| 125 | + |
| 126 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 127 | + num_tensors=num_tensors, |
| 128 | + shape=shape, |
| 129 | + quantizer=None, |
| 130 | + device="cuda", |
| 131 | + dtype=torch.float32, |
| 132 | + ) |
| 133 | + |
| 134 | + assert grouped_tensor.num_tensors == num_tensors |
| 135 | + assert grouped_tensor.all_same_shape() |
| 136 | + assert grouped_tensor.all_same_first_dim() |
| 137 | + assert grouped_tensor.all_same_last_dim() |
| 138 | + assert grouped_tensor.logical_shape == (num_tensors * 256, 512) |
| 139 | + assert grouped_tensor.get_common_first_dim() == 256 |
| 140 | + assert grouped_tensor.get_common_last_dim() == 512 |
| 141 | + assert grouped_tensor.has_data() |
| 142 | + |
| 143 | + def test_basic_construction_varying_first_dim(self) -> None: |
| 144 | + """Test GroupedTensor construction with varying first dimension""" |
| 145 | + num_tensors = 3 |
| 146 | + shape = [(128, 512), (256, 512), (384, 512)] |
| 147 | + |
| 148 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 149 | + num_tensors=num_tensors, |
| 150 | + shape=shape, |
| 151 | + quantizer=None, |
| 152 | + device="cuda", |
| 153 | + dtype=torch.float32, |
| 154 | + ) |
| 155 | + |
| 156 | + assert grouped_tensor.num_tensors == num_tensors |
| 157 | + assert not grouped_tensor.all_same_shape() |
| 158 | + assert not grouped_tensor.all_same_first_dim() |
| 159 | + assert grouped_tensor.all_same_last_dim() |
| 160 | + assert grouped_tensor.get_common_last_dim() == shape[0][1] |
| 161 | + assert grouped_tensor.logical_shape == ( |
| 162 | + sum(v for v, _ in shape), |
| 163 | + shape[0][1], |
| 164 | + ) # sum of first dims |
| 165 | + |
| 166 | + def test_split_into_quantized_tensors_no_quantization(self) -> None: |
| 167 | + """Test split_into_quantized_tensors for unquantized tensors""" |
| 168 | + num_tensors = 3 |
| 169 | + shape = [(256, 512) for _ in range(num_tensors)] |
| 170 | + |
| 171 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 172 | + num_tensors=num_tensors, |
| 173 | + shape=shape, |
| 174 | + quantizer=None, |
| 175 | + device="cuda", |
| 176 | + dtype=torch.float32, |
| 177 | + ) |
| 178 | + |
| 179 | + # Get the original data pointer |
| 180 | + original_data_ptr = grouped_tensor.data.data_ptr() |
| 181 | + |
| 182 | + # Split into tensors |
| 183 | + tensors = grouped_tensor.split_into_quantized_tensors() |
| 184 | + |
| 185 | + assert len(tensors) == num_tensors |
| 186 | + |
| 187 | + # Verify each tensor has correct shape and shares storage |
| 188 | + for i, tensor in enumerate(tensors): |
| 189 | + assert tensor.shape == shape[i] |
| 190 | + assert isinstance(tensor, torch.Tensor) |
| 191 | + assert not hasattr(tensor, "_data") # Not a quantized tensor |
| 192 | + |
| 193 | + # Verify data pointer is within the original grouped tensor storage |
| 194 | + # The tensor should be a view of the original data |
| 195 | + assert tensor.data_ptr() >= original_data_ptr |
| 196 | + |
| 197 | + # Calculate expected offset |
| 198 | + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() |
| 199 | + assert tensor.data_ptr() == original_data_ptr + expected_offset |
| 200 | + |
| 201 | + @pytest.mark.parametrize("quantization", _quantization_params) |
| 202 | + def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: |
| 203 | + """Test split_into_quantized_tensors for quantized tensors""" |
| 204 | + num_tensors = 3 |
| 205 | + shape = [(512, 512) for _ in range(num_tensors)] |
| 206 | + quantizers = make_quantizer(quantization, num_tensors, shape) |
| 207 | + |
| 208 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 209 | + num_tensors=num_tensors, |
| 210 | + shape=shape, |
| 211 | + quantizer=quantizers, |
| 212 | + device="cuda", |
| 213 | + ) |
| 214 | + |
| 215 | + # Get the original data pointer |
| 216 | + original_data_ptr = grouped_tensor.data.data_ptr() |
| 217 | + |
| 218 | + # Split into tensors |
| 219 | + tensors = grouped_tensor.split_into_quantized_tensors() |
| 220 | + |
| 221 | + assert len(tensors) == num_tensors |
| 222 | + |
| 223 | + # Verify each tensor shares storage with the grouped tensor |
| 224 | + for i, tensor in enumerate(tensors): |
| 225 | + rowwise_data = _get_rowwise_data_tensor(tensor, quantization) |
| 226 | + assert rowwise_data is not None |
| 227 | + assert rowwise_data.data_ptr() >= original_data_ptr |
| 228 | + numel = shape[i][0] * shape[i][1] |
| 229 | + expected_offset = _rowwise_offset_bytes(i * numel, quantization) |
| 230 | + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset |
| 231 | + |
| 232 | + def test_split_varying_shapes(self) -> None: |
| 233 | + """Test split_into_quantized_tensors with varying shapes""" |
| 234 | + num_tensors = 3 |
| 235 | + shape = [(128, 512), (256, 512), (384, 512)] |
| 236 | + |
| 237 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 238 | + num_tensors=num_tensors, |
| 239 | + shape=shape, |
| 240 | + quantizer=None, |
| 241 | + device="cuda", |
| 242 | + dtype=torch.float32, |
| 243 | + ) |
| 244 | + |
| 245 | + original_data_ptr = grouped_tensor.data.data_ptr() |
| 246 | + tensors = grouped_tensor.split_into_quantized_tensors() |
| 247 | + |
| 248 | + assert len(tensors) == num_tensors |
| 249 | + |
| 250 | + # Verify shapes and storage |
| 251 | + cumulative_offset = 0 |
| 252 | + for i, tensor in enumerate(tensors): |
| 253 | + assert tensor.shape == shape[i] |
| 254 | + expected_offset = cumulative_offset * tensor.element_size() |
| 255 | + assert tensor.data_ptr() == original_data_ptr + expected_offset |
| 256 | + cumulative_offset += shape[i][0] * shape[i][1] |
| 257 | + |
| 258 | + @pytest.mark.parametrize("quantization", _quantization_params) |
| 259 | + def test_quantize_inplace(self, quantization: str) -> None: |
| 260 | + """Test that quantize is done in-place for all recipes""" |
| 261 | + num_tensors = 3 |
| 262 | + shape = [(512, 512) for _ in range(num_tensors)] |
| 263 | + quantizers = make_quantizer(quantization, num_tensors, shape) |
| 264 | + |
| 265 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 266 | + num_tensors=num_tensors, |
| 267 | + shape=shape, |
| 268 | + quantizer=quantizers, |
| 269 | + device="cuda", |
| 270 | + ) |
| 271 | + |
| 272 | + # Get original data pointers before quantization |
| 273 | + original_data_ptr = grouped_tensor.data.data_ptr() |
| 274 | + original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() |
| 275 | + original_scale_ptr = ( |
| 276 | + grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None |
| 277 | + ) |
| 278 | + |
| 279 | + # Create input tensors |
| 280 | + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] |
| 281 | + |
| 282 | + # Quantize in place |
| 283 | + quantized_tensors = grouped_tensor.quantize(input_tensors) |
| 284 | + |
| 285 | + # Verify data pointers haven't changed (in-place operation) |
| 286 | + assert grouped_tensor.data.data_ptr() == original_data_ptr |
| 287 | + assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr |
| 288 | + if original_scale_ptr is not None: |
| 289 | + assert grouped_tensor.scale.data_ptr() == original_scale_ptr |
| 290 | + |
| 291 | + # Verify returned tensors point to the same storage |
| 292 | + for i, qtensor in enumerate(quantized_tensors): |
| 293 | + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) |
| 294 | + numel = shape[i][0] * shape[i][1] |
| 295 | + expected_offset = _rowwise_offset_bytes(i * numel, quantization) |
| 296 | + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset |
| 297 | + |
| 298 | + @pytest.mark.parametrize("quantization", _quantization_params) |
| 299 | + def test_quantize_varying_shapes(self, quantization: str) -> None: |
| 300 | + """Test quantize with varying shapes""" |
| 301 | + num_tensors = 3 |
| 302 | + shape = [(256, 512), (512, 512), (768, 512)] |
| 303 | + quantizers = make_quantizer(quantization, num_tensors, shape) |
| 304 | + |
| 305 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 306 | + num_tensors=num_tensors, |
| 307 | + shape=shape, |
| 308 | + quantizer=quantizers, |
| 309 | + device="cuda", |
| 310 | + ) |
| 311 | + |
| 312 | + # Get original data pointers |
| 313 | + original_data_ptr = grouped_tensor.data.data_ptr() |
| 314 | + |
| 315 | + # Create input tensors with varying shapes |
| 316 | + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] |
| 317 | + |
| 318 | + # Quantize in place |
| 319 | + quantized_tensors = grouped_tensor.quantize(input_tensors) |
| 320 | + |
| 321 | + # Verify data pointer hasn't changed |
| 322 | + assert grouped_tensor.data.data_ptr() == original_data_ptr |
| 323 | + |
| 324 | + # Verify each tensor points to correct location |
| 325 | + cumulative_numel = 0 |
| 326 | + for qtensor, tensor_shape in zip(quantized_tensors, shape): |
| 327 | + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) |
| 328 | + expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) |
| 329 | + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset |
| 330 | + cumulative_numel += tensor_shape[0] * tensor_shape[1] |
| 331 | + |
| 332 | + @pytest.mark.parametrize("quantization", _quantization_params) |
| 333 | + def test_static_quantize_method(self, quantization: str) -> None: |
| 334 | + """Test the static quantize method""" |
| 335 | + num_tensors = 3 |
| 336 | + shape = [(512, 512) for _ in range(num_tensors)] |
| 337 | + quantizers = make_quantizer(quantization, num_tensors, shape) |
| 338 | + |
| 339 | + # Create input tensors |
| 340 | + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] |
| 341 | + |
| 342 | + # Use static quantize method |
| 343 | + grouped_tensor = GroupedTensor.create_and_quantize( |
| 344 | + tensors=input_tensors, |
| 345 | + quantizer=quantizers, |
| 346 | + device="cuda", |
| 347 | + ) |
| 348 | + |
| 349 | + # Verify the grouped tensor was created correctly |
| 350 | + assert grouped_tensor.num_tensors == num_tensors |
| 351 | + assert grouped_tensor.has_data() |
| 352 | + |
| 353 | + # Verify quantized_tensors were created and point to same storage |
| 354 | + assert grouped_tensor.quantized_tensors is not None |
| 355 | + assert len(grouped_tensor.quantized_tensors) == num_tensors |
| 356 | + |
| 357 | + original_data_ptr = grouped_tensor.data.data_ptr() |
| 358 | + for i, qtensor in enumerate(grouped_tensor.quantized_tensors): |
| 359 | + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) |
| 360 | + numel = shape[i][0] * shape[i][1] |
| 361 | + expected_offset = _rowwise_offset_bytes(i * numel, quantization) |
| 362 | + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset |
| 363 | + |
| 364 | + def test_clear(self) -> None: |
| 365 | + """Test clear method""" |
| 366 | + num_tensors = 3 |
| 367 | + shape = [(256, 512) for _ in range(num_tensors)] |
| 368 | + |
| 369 | + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( |
| 370 | + num_tensors=num_tensors, |
| 371 | + shape=shape, |
| 372 | + quantizer=None, |
| 373 | + device="cuda", |
| 374 | + dtype=torch.float32, |
| 375 | + ) |
| 376 | + |
| 377 | + assert grouped_tensor.has_data() |
| 378 | + assert grouped_tensor.num_tensors == num_tensors |
| 379 | + |
| 380 | + grouped_tensor.clear() |
| 381 | + |
| 382 | + assert not grouped_tensor.has_data() |
| 383 | + assert grouped_tensor.num_tensors == 0 |
| 384 | + assert grouped_tensor.data is None |
| 385 | + assert grouped_tensor.logical_shape == (0, 0) |
0 commit comments