From 27196ecdc4e94a1de28897069f61ce682b015a96 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Thu, 26 Feb 2026 14:47:31 -0800
Subject: [PATCH] Enable dequantization from just columnwise data

Signed-off-by: Przemek Tredak
---
Notes (review):
    The first hunk adds TestMXFP8Tensor. Its first test quantizes with both
    rowwise and columnwise data, drops the rowwise data via update_usage()
    and dequantizes again; its second test quantizes with columnwise data
    only. Both compare the dequantized result against the original tensor.

    The second hunk makes forward() call tex.dequantize() when either
    _rowwise_data or _columnwise_data is set, and raise ValueError (was
    NotImplementedError) when both are None.

    NOTE(review): indentation of the unchanged context lines below was
    reconstructed from Python syntax; verify against the target tree
    before applying.

    NOTE(review): the From: and Signed-off-by: lines carry no e-mail
    address; presumably it was stripped in transit - confirm before
    running git am.

 tests/pytorch/test_quantized_tensor.py        | 83 +++++++++++++++++++
 .../tensor/storage/mxfp8_tensor_storage.py    |  4 +-
 2 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py
index b2e8fca7cb..978ec09b40 100644
--- a/tests/pytorch/test_quantized_tensor.py
+++ b/tests/pytorch/test_quantized_tensor.py
@@ -656,3 +656,86 @@ def test_chunk(
         tols = dict(rtol=0, atol=0)  # Chunking is exact
         y_test = y_test.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(y_test, y_ref, **tols)
+
+
+@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
+class TestMXFP8Tensor:
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]])
+    def test_mxfp8_dequantize_columnwise_only(
+        self,
+        fp8_dtype: tex.DType,
+        dtype: torch.dtype,
+        dims: DimsType,
+    ) -> None:
+        """Check dequantization of MXFP8 tensor with only columnwise data"""
+
+        # Initialize random data
+        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1
+
+        # Quantize with both rowwise and columnwise
+        quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
+        x_mxfp8 = quantizer(x_ref)
+
+        # Dequantize from rowwise (default path)
+        x_deq_rowwise = x_mxfp8.dequantize(dtype=dtype)
+
+        # Rowwise dequantization should be close to the original
+        torch.testing.assert_close(x_deq_rowwise, x_ref, **_tols[fp8_dtype])
+
+        # Strip rowwise data, keeping only columnwise
+        x_mxfp8.update_usage(rowwise_usage=False, columnwise_usage=True)
+        assert x_mxfp8._rowwise_data is None
+        assert x_mxfp8._columnwise_data is not None
+
+        # Dequantize from columnwise only
+        x_deq_columnwise = x_mxfp8.dequantize(dtype=dtype)
+
+        # Columnwise dequantization should be close to the original
+        torch.testing.assert_close(x_deq_columnwise, x_ref, **_tols[fp8_dtype])
+
+        # Rowwise and columnwise dequantizations should match each other
+        torch.testing.assert_close(x_deq_columnwise, x_deq_rowwise, **_tols[fp8_dtype])
+
+        # Make sure we are not trivially passing the test
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(x_deq_columnwise, -x_ref, **_tols[fp8_dtype])
+
+    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
+    @pytest.mark.parametrize("dims", [[128, 128], [256, 256]])
+    def test_mxfp8_dequantize_columnwise_only_quantized_separately(
+        self,
+        fp8_dtype: tex.DType,
+        dims: DimsType,
+    ) -> None:
+        """Check dequantization of MXFP8 tensor quantized with columnwise only"""
+
+        dtype = torch.bfloat16
+
+        # Initialize random data
+        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1
+
+        # Quantize with columnwise only (no rowwise)
+        quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True)
+        x_mxfp8 = quantizer(x_ref)
+        assert x_mxfp8._rowwise_data is None
+        assert x_mxfp8._columnwise_data is not None
+
+        # Dequantize from columnwise only
+        x_deq = x_mxfp8.dequantize(dtype=dtype)
+
+        # Should be close to the original
+        torch.testing.assert_close(x_deq, x_ref, **_tols[fp8_dtype])
+
+        # Make sure we are not trivially passing the test
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(x_deq, -x_ref, **_tols[fp8_dtype])
diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
index 5c8510488f..e3decce4c2 100644
--- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
@@ -33,9 +33,9 @@ def forward(
         dtype = torch_to_transformer_engine_dtype[dtype]
 
         # Make sure FP8 data is in expected format
-        if tensor._rowwise_data is not None:
+        if tensor._rowwise_data is not None or tensor._columnwise_data is not None:
             return tex.dequantize(tensor, dtype)
-        raise NotImplementedError("Casting back from the transpose not implemented yet!")
+        raise ValueError("Cannot dequantize MXFP8 tensor with no data")
 
     @staticmethod
     def backward(