From 066a56440f763f118328785c99ed9b24327d13a0 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 6 Jan 2026 15:11:58 +0200 Subject: [PATCH 01/11] Add support for MXFP8 PTQ Signed-off-by: Daniel Serebrenik --- examples/llm_ptq/hf_ptq.py | 3 + modelopt/torch/export/model_config.py | 1 + modelopt/torch/export/quant_utils.py | 72 ++++ modelopt/torch/export/unified_export_hf.py | 36 ++ .../nn/modules/tensor_quantizer.py | 6 + .../torch/quantization/qtensor/__init__.py | 1 + .../quantization/qtensor/mxfp8_tensor.py | 277 ++++++++++++++ .../torch/quantization/test_qtensor_cuda.py | 337 +++++++++++++++++- 8 files changed, 732 insertions(+), 1 deletion(-) create mode 100644 modelopt/torch/quantization/qtensor/mxfp8_tensor.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a9862a742..1dba72324 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -83,6 +83,7 @@ "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, + "mxfp8": mtq.MXFP8_DEFAULT_CFG, } KV_QUANT_CFG_CHOICES = { @@ -184,6 +185,7 @@ def auto_quantize( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "mxfp8", ] for args.qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" @@ -766,6 +768,7 @@ def quantize_main( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "mxfp8", ] or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES ), f"Plain quantization format {args.qformat} not supported for HF export path" diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py index 306348f2c..9553b4fcf 100755 --- a/modelopt/torch/export/model_config.py +++ b/modelopt/torch/export/model_config.py @@ -35,6 +35,7 @@ QUANTIZATION_NVFP4 = "nvfp4" QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8" QUANTIZATION_MXFP4 = "mxfp4" +QUANTIZATION_MXFP8 = "mxfp8" QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8" QUANTIZATION_NVFP4_AWQ = "nvfp4_awq" QUANTIZATION_FP8_PB_REAL = "fp8_pb_real" diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index eee13dc51..21459b69a 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -16,6 +16,7 @@ """Utils for quantization including scaling factors adjustments.""" import logging +import math from collections.abc import Generator from types import SimpleNamespace from typing import Any @@ -30,6 +31,7 @@ from modelopt.torch.quantization.qtensor import ( FP8QTensor, MXFP4QTensor, + MXFP8QTensor, NVFP4QTensor, QTensorWrapper, ) @@ -54,6 +56,7 @@ QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO, QUANTIZATION_MXFP4, + QUANTIZATION_MXFP8, QUANTIZATION_NONE, QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, @@ -290,6 +293,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[ 1 ].reshape(*weight.shape[:-1], -1) + + if quantization_format == QUANTIZATION_MXFP8: + return MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, weight_quantizer) return get_scaling_factor(weight_quantizer) @@ -474,6 +480,14 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames if weight_quantizer.num_bits == (4, 3): if weight_quantizer.block_sizes: assert weight_quantizer.block_sizes[-1] > 0, "Invalid block_sizes for FP8 quantizer" + # Check if this is MXFP8 (dynamic block quantization with scale_bits (8, 0)) + block_sizes = 
getattr(weight_quantizer, "block_sizes") + if ( + isinstance(block_sizes, dict) + and block_sizes.get("type", "static") == "dynamic" + and block_sizes.get("scale_bits") == (8, 0) + ): + return QUANTIZATION_MXFP8 if weight_quantizer.fake_quant: return QUANTIZATION_FP8_PB_WO else: @@ -669,6 +683,11 @@ def process_layer_quant_config(layer_config_dict): "quant_algo": "W4A8_MXFP4_FP8", "group_size": block_size_value, } + elif v == "mxfp8": + layer_config = { + "quant_algo": "MXFP8", + "group_size": block_size_value, + } else: layer_config = {"quant_algo": v} @@ -738,6 +757,56 @@ def pack_int4_in_uint8(weight, weights_scaling_factor): return packed_byte.T.contiguous().view(torch.uint8) +def _quantize_weight_mxfp8( + weight: torch.Tensor, + weights_scaling_factor: torch.Tensor, +) -> torch.Tensor: + """Quantize weight tensor using MXFP8 format. + + MXFP8 uses dynamic block quantization with FP8 (E4M3) along dimension -1 only (1D blocking). + Scales are E8M0 format (power-of-2 only), stored as biased uint8 exponents. + """ + assert weights_scaling_factor is not None, ( + "weights_scaling_factor must be provided for MXFP8 quantization." + ) + + # MXFP8 block size is 32 + block_size = 32 + maxbound = torch.finfo(torch.float8_e4m3fn).max # 448.0 + + out_dim, in_dim = weight.shape[-2], weight.shape[-1] + expected_shape = (out_dim, in_dim // block_size) + + # Reshape scaling factor if needed (same number of elements but wrong shape) + if weights_scaling_factor.shape != expected_shape: + if weights_scaling_factor.numel() == math.prod(expected_shape): + weights_scaling_factor = weights_scaling_factor.reshape(expected_shape) + + # Handle E8M0 uint8 scale format (biased exponent) + if weights_scaling_factor.dtype == torch.uint8: + # E8M0 format: descale = 2^(exponent - 127), scale = 2^(127 - exponent) + scale_factor = torch.exp2(127 - weights_scaling_factor.float()) + else: + # Legacy float32 scale format: scale = amax / maxbound + # Convert to E8M0: exponent = ceil(log2(scale)) + e8m0_exponent = torch.ceil(torch.log2(weights_scaling_factor.clamp(min=2**-127))) + e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) + scale_factor = torch.exp2(-e8m0_exponent) + + # Reshape weight to [out_dim, num_blocks, block_size] + num_blocks = in_dim // block_size + weight_reshaped = weight.view(out_dim, num_blocks, block_size) + + # Apply scale and quantize to FP8 E4M3 + scale_factor_expanded = scale_factor.unsqueeze(-1) + scaled_weight = weight_reshaped * scale_factor_expanded + scaled_weight = torch.clamp(scaled_weight, min=-maxbound, max=maxbound) + quantized_weight = scaled_weight.to(torch.float8_e4m3fn) + + # Reshape back to original 2D shape + return quantized_weight.view(out_dim, in_dim) + + def to_quantized_weight( weight: torch.Tensor, weights_scaling_factor: torch.Tensor, @@ -773,6 +842,9 @@ def to_quantized_weight( if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]: return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8) + if quantization == QUANTIZATION_MXFP8: + return _quantize_weight_mxfp8(weight, weights_scaling_factor) + if quantization == QUANTIZATION_FP8_PB_WO: return FP8QTensor.quantize( weight, weights_scaling_factor.squeeze(), block_sizes={-1: block_size, -2: block_size} diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 1dd1c1822..96cb7b16a 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -51,6 +51,7 @@ QUANTIZATION_FP8, 
QUANTIZATION_FP8_PB_REAL, QUANTIZATION_FP8_PC_PT, + QUANTIZATION_MXFP8, QUANTIZATION_NONE, QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, @@ -297,6 +298,41 @@ def _export_quantized_weight( weight_quantizer._scale.to(torch.float32), ) del weight_quantizer._scale + elif quantization_format == QUANTIZATION_MXFP8: + # MXFP8 uses dynamic block quantization with E8M0 scales (uint8) + if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: + # If _scale is already uint8 (from MXFP8QTensor.quantize), keep it as is + if weight_quantizer._scale.dtype == torch.uint8: + sub_module.register_buffer( + quantizer_attrs.weight_scale, + weight_quantizer._scale, + ) + else: + # Legacy path: convert float32 scale to E8M0 uint8 + # scale = amax / E4M3_max, so exponent = ceil(log2(scale)) + scale = weight_quantizer._scale.to(torch.float32) + e8m0_exponent = torch.ceil(torch.log2(scale.clamp(min=2**-127))) + e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) + e8m0_scale = (e8m0_exponent + 127).to(torch.uint8) + sub_module.register_buffer( + quantizer_attrs.weight_scale, + e8m0_scale, + ) + del weight_quantizer._scale + else: + # Compute E8M0 scaling factor from weight tensor + weight = getattr(sub_module, weight_name) + block_size = weight_quantizer.block_sizes[-1] + out_dim, in_dim = weight.shape[-2], weight.shape[-1] + num_blocks = in_dim // block_size + weight_reshaped = weight.view(out_dim, num_blocks, block_size) + amax = weight_reshaped.float().abs().max(dim=-1)[0] + maxbound = torch.finfo(torch.float8_e4m3fn).max + descale = amax / maxbound + e8m0_exponent = torch.ceil(torch.log2(descale.clamp(min=2**-127))) + e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) + e8m0_scale = (e8m0_exponent + 127).to(torch.uint8) + sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale) else: sub_module.register_buffer( quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 7d3fa1251..bc09c4ddf 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -49,6 +49,7 @@ INT4QTensor, INT8QTensor, MXFP4QTensor, + MXFP8QTensor, NF4QTensor, NVFP4QTensor, QTensorWrapper, @@ -689,8 +690,13 @@ def _real_quantize(self, inputs): ): # MX quantization if self._num_bits == (2, 1): + # MXFP4 outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) buffer_to_register["_scale"] = scales + elif self._num_bits == (4, 3): + # MXFP8 + outputs, scales = MXFP8QTensor.quantize(inputs, self._block_sizes[-1]) + buffer_to_register["_scale"] = scales else: raise ValueError( f"Real quantization for MX {self._num_bits} format is not supported." 
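For reference, the essence of the MXFP8 path wired into `_real_quantize` above is: split the tensor into 32-element blocks along the last dimension, pick one power-of-two scale per block whose exponent is stored as a biased uint8 (E8M0), and cast the scaled block to FP8 E4M3. The short sketch below is not part of the patch and uses a made-up helper name; it reproduces that round trip in plain PyTorch for a tensor whose last dimension is already a multiple of 32 (the real `MXFP8QTensor.quantize` added later in this patch additionally pads and crops when it is not).

import torch

BLOCK = 32
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def mxfp8_roundtrip_sketch(x: torch.Tensor):
    """Illustrative MXFP8 round trip: one E8M0 scale per 32-element block, elements in FP8 E4M3."""
    blocks = x.view(-1, BLOCK).float()
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    # E8M0 exponent: ceil(log2(amax / 448)), clamped to [-127, 127]; all-zero blocks map to -127.
    exp = torch.clamp(torch.ceil(torch.log2(amax / E4M3_MAX)), min=-127, max=127)
    exp = torch.where(amax > 0, exp, torch.full_like(exp, -127.0))
    scale_u8 = (exp + 127).to(torch.uint8)                 # biased exponent, as exported
    q = (blocks * torch.exp2(-exp)).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    deq = (q.float() * torch.exp2(exp)).view_as(x)         # descale = 2^(stored - 127)
    return q.view(x.shape), scale_u8.view(*x.shape[:-1], -1), deq

x = torch.randn(4, 64)
q, scale, deq = mxfp8_roundtrip_sketch(x)
assert q.dtype == torch.float8_e4m3fn and scale.dtype == torch.uint8
assert torch.allclose(deq, x, rtol=5e-2, atol=5e-2)

Because the exponent is clamped to [-127, 127], the exported biased values stay in [0, 254]; 255 (0xFF) is the E8M0 NaN encoding and never appears for finite weights, which is what the new unit tests assert.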
diff --git a/modelopt/torch/quantization/qtensor/__init__.py b/modelopt/torch/quantization/qtensor/__init__.py index c4ed88f87..9c623c1bd 100644 --- a/modelopt/torch/quantization/qtensor/__init__.py +++ b/modelopt/torch/quantization/qtensor/__init__.py @@ -20,5 +20,6 @@ from .int4_tensor import * from .int8_tensor import * from .mxfp4_tensor import * +from .mxfp8_tensor import * from .nf4_tensor import * from .nvfp4_tensor import * diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py new file mode 100644 index 000000000..23e50220e --- /dev/null +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements MXFP8 quantization for efficient tensor storage and computation.""" + +import torch + +from ..qtensor.base_qtensor import BaseQuantizedTensor + +__all__ = ["MXFP8QTensor"] + + +class MXFP8QTensor(BaseQuantizedTensor): + """Implements the MXFP8 quantization on tensors for more efficient storage or computation. + + MXFP8 uses: + - FP8 E4M3 format for elements + - E8M0 format for shared scales (power-of-2 only, stored as biased uint8 exponent) + - Block size of 32 elements along the last dimension + + Attributes: + quantized_data (torch.Tensor): The quantized data stored as float8_e4m3fn tensor. + """ + + E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + BLOCK_SIZE = 32 + + @classmethod + def _compute_e8m0_exponent(cls, amax: torch.Tensor) -> torch.Tensor: + """Compute E8M0 exponent from per-block amax values. + + Args: + amax: Per-block absolute max values. + + Returns: + torch.Tensor: Float tensor of E8M0 exponents (unbiased, range [-127, 127]). + """ + # Compute E8M0 scale: scale = 2^ceil(log2(amax / E4M3_max)) + descale = amax / cls.E4M3_MAX + + # Handle zero/inf/nan cases + min_value = torch.tensor(-127.0, device=descale.device) + log2_descale = torch.where( + descale > 0, + torch.log2(descale), + min_value, + ) + + # Ceil to get power-of-2 scale + e8m0_exponent = torch.ceil(log2_descale) + + # Clamp exponent to valid E8M0 range [-127, 127] + return torch.clamp(e8m0_exponent, min=-127, max=127) + + @classmethod + def get_weights_scaling_factor( + cls, + weight: torch.Tensor, + block_size: int | None = None, + ) -> torch.Tensor: + """Returns E8M0 scale (uint8 biased exponent) for weight tensor. + + Args: + weight: The weight tensor to compute scale for. Must be 2D. + block_size: The block size for quantization. Defaults to 32. + + Returns: + torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // block_size]. 
+ """ + if block_size is None: + block_size = cls.BLOCK_SIZE + + assert block_size > 0, f"block_size must be positive, got {block_size}" + assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" + + out_dim, in_dim = weight.shape[-2], weight.shape[-1] + + assert in_dim % block_size == 0, ( + f"Weight inner dimension ({in_dim}) must be divisible by block_size ({block_size})" + ) + + # Reshape to [out_dim, num_blocks, block_size] + weight_reshaped = weight.view(out_dim, in_dim // block_size, block_size) + + # Compute amax per block + amax = weight_reshaped.float().abs().max(dim=-1)[0] + + # Compute E8M0 exponent and convert to biased uint8 (bias = 127) + e8m0_exponent = cls._compute_e8m0_exponent(amax) + return (e8m0_exponent + 127).to(torch.uint8) + + @classmethod + def get_weights_scaling_factor_from_quantizer( + cls, + weight: torch.Tensor, + weight_quantizer, + ) -> torch.Tensor: + """Returns E8M0 scale from quantizer or computes from weight. + + This method handles extracting the scale from a weight quantizer, + with proper format conversion and shape correction. + + Args: + weight: The weight tensor. + weight_quantizer: The weight quantizer with block_sizes and optional _scale. + + Returns: + torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // block_size]. + """ + assert hasattr(weight_quantizer, "block_sizes"), ( + "weight_quantizer must have 'block_sizes' attribute" + ) + assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" + + block_size = weight_quantizer.block_sizes[-1] + out_dim, in_dim = weight.shape[-2], weight.shape[-1] + expected_shape = (out_dim, in_dim // block_size) + + if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: + scale = weight_quantizer._scale + + # If already uint8 E8M0 format, return as-is (with shape correction if needed) + if scale.dtype == torch.uint8: + if ( + scale.shape != expected_shape + and scale.numel() == expected_shape[0] * expected_shape[1] + ): + scale = scale.reshape(expected_shape) + return scale + + # Legacy float32 scale - convert to E8M0 uint8 + if scale.shape != expected_shape: + if scale.numel() == expected_shape[0] * expected_shape[1]: + scale = scale.reshape(expected_shape) + else: + # Shape mismatch, recompute from weight + return cls.get_weights_scaling_factor(weight, block_size) + + # Convert float32 scale to E8M0 uint8 + e8m0_exponent = torch.ceil(torch.log2(scale.clamp(min=2**-127))) + e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) + return (e8m0_exponent + 127).to(torch.uint8) + + # No scale in quantizer, compute from weight + return cls.get_weights_scaling_factor(weight, block_size) + + @classmethod + def quantize(cls, input: torch.Tensor, block_size: int | None = None) -> tuple: + """Convert a tensor to MXFP8 quantized format. + + Args: + input (torch.Tensor): The input tensor to be quantized. + block_size (int | None): The block size for quantization. Defaults to 32. + + Returns: + tuple: (MXFP8QTensor, e8m0_scale) where e8m0_scale is uint8 biased exponent. 
+ """ + if block_size is None: + block_size = cls.BLOCK_SIZE + + assert block_size > 0, f"block_size must be positive, got {block_size}" + assert input.numel() > 0, "Input tensor must not be empty" + assert input.dim() >= 1, f"Input must have at least 1 dimension, got {input.dim()}D" + + original_shape = input.shape + original_dtype = input.dtype + + # Pad last dimension if not divisible by block_size + last_dim = original_shape[-1] + if last_dim % block_size != 0: + pad_size = block_size - (last_dim % block_size) + input = torch.nn.functional.pad(input, (0, pad_size), mode="constant", value=0) + + # Flatten to [num_blocks, block_size] for block-wise quantization + input_flat = input.view(-1, block_size) + + # Compute amax per block + input_amax = input_flat.float().abs().max(dim=-1, keepdim=True).values + + # Compute E8M0 exponent and scale factor + e8m0_exponent = cls._compute_e8m0_exponent(input_amax) + scale_factor = torch.exp2(-e8m0_exponent) + + # Apply scale and quantize to FP8 E4M3 + scaled_input = input_flat * scale_factor + + # Clamp to E4M3 range [-448, 448] and convert + scaled_input = torch.clamp(scaled_input, min=-cls.E4M3_MAX, max=cls.E4M3_MAX) + quantized_data = scaled_input.to(torch.float8_e4m3fn) + + # Reshape: account for padded shape, then crop back to original + padded_shape = list(original_shape) + padded_shape[-1] = input.shape[-1] + quantized_data = quantized_data.view(padded_shape) + quantized_data = quantized_data[..., :last_dim] + + # Convert exponent to biased uint8 (bias = 127) + e8m0_scale = (e8m0_exponent + 127).to(torch.uint8) + + # Reshape scale to preserve leading dimensions: (*original_shape[:-1], padded_last_dim // block_size) + scale_shape = [*original_shape[:-1], input.shape[-1] // block_size] + e8m0_scale = e8m0_scale.view(scale_shape) + + return cls(original_shape, original_dtype, quantized_data), e8m0_scale + + def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: + """Dequantize MXFP8 tensor back to the target dtype. + + Args: + dtype (torch.dtype | None): Target dtype for dequantization. Defaults to original dtype. + **kwargs: Must contain 'scale' (E8M0 biased uint8) and 'block_sizes'. + + Returns: + torch.Tensor: Dequantized tensor in the target dtype. 
+ """ + assert "scale" in kwargs, "dequantize requires 'scale' in kwargs" + assert "block_sizes" in kwargs, "dequantize requires 'block_sizes' in kwargs" + + e8m0_scale = kwargs["scale"] + block_size = kwargs["block_sizes"][-1] + + assert block_size > 0, f"block_size must be positive, got {block_size}" + + if dtype is None: + dtype = self.metadata["dtype"] + + original_shape = self.metadata["shape"] + last_dim = original_shape[-1] + quantized_data = self._quantized_data + + # Validate scale shape matches expected number of blocks + padded_last_dim = last_dim + (block_size - last_dim % block_size) % block_size + expected_num_blocks = (quantized_data.numel() // last_dim) * (padded_last_dim // block_size) + assert e8m0_scale.numel() == expected_num_blocks, ( + f"Scale has {e8m0_scale.numel()} elements but expected {expected_num_blocks} blocks" + ) + + # Pad last dimension if not divisible by block_size + if last_dim % block_size != 0: + pad_size = block_size - (last_dim % block_size) + quantized_data = torch.nn.functional.pad( + quantized_data.float(), (0, pad_size), mode="constant", value=0 + ) + else: + quantized_data = quantized_data.float() + + # Flatten to [num_blocks, block_size] for block-wise dequantization + quantized_flat = quantized_data.view(-1, block_size) + + # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127) + descale = torch.exp2(e8m0_scale.float() - 127) + + # Flatten scale to (num_blocks, 1) for broadcasting with quantized_flat + descale = descale.view(-1, 1) + + # Apply descale + dequantized = quantized_flat * descale + + # Reshape: account for padded shape, then crop back to original + padded_shape = list(original_shape) + padded_shape[-1] = quantized_data.shape[-1] + dequantized = dequantized.view(padded_shape) + dequantized = dequantized[..., :last_dim] + + return dequantized.to(dtype) diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index 26df7a8c8..bdbfc6565 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -22,7 +22,7 @@ from modelopt.torch.quantization.backends.utils import fp4_compatible from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import TensorQuantizer -from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor set_seed() @@ -248,6 +248,14 @@ def test_amax_from_tensor_quantizer( torch.randn([512, 512], dtype=torch.float32), None, ), + # MXFP8 + ( + (4, 3), + {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + None, + torch.randn([512, 512], dtype=torch.float32), + None, + ), ], ) @pytest.mark.parametrize("device", ["cpu", "cuda"]) @@ -602,3 +610,330 @@ def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, b assert torch.allclose(deq_x, x, rtol=1e-1, atol=1e-1) assert hasattr(quantizer, "_scale") assert quantizer._scale.numel() > 1 + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("block_size", [32]) + @pytest.mark.parametrize( + "input_shape", + [(128, 128), (256, 64), (512, 512)], + ) + def test_mxfp8_quantize_dequantize(self, device, input_dtype, block_size, input_shape): + """Test MXFP8 quantization and dequantization produces correct E8M0 scales.""" + # Create test tensor + test_tensor = 
torch.randn(input_shape, dtype=input_dtype, device=device) + + # Quantize using MXFP8QTensor + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) + + # Verify scale is uint8 (E8M0 format) + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + + # Verify scale shape is [out_dim, in_dim // block_size] + expected_scale_shape = (input_shape[0], input_shape[1] // block_size) + assert e8m0_scale.shape == expected_scale_shape, ( + f"Expected scale shape {expected_scale_shape}, got {e8m0_scale.shape}" + ) + + # Verify quantized data is FP8 E4M3 + assert qtensor._quantized_data.dtype == torch.float8_e4m3fn, ( + f"Expected float8_e4m3fn, got {qtensor._quantized_data.dtype}" + ) + + # Dequantize + dequant_tensor = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + block_sizes={-1: block_size}, + ) + + # Verify dequantized tensor is close to original + assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: " + f"max diff = {(dequant_tensor - test_tensor).abs().max()}" + ) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_e8m0_scale_values(self, device): + """Test that MXFP8 produces correct E8M0 scale values (power-of-2 only).""" + # Create a tensor with known amax values per block + # Block size is 32, so create a 2x64 tensor (2 rows, 2 blocks per row) + block_size = 32 + test_tensor = torch.zeros((2, 64), dtype=torch.float32, device=device) + + # First block (row 0, elements 0-31): max abs = 1.0, should give exponent ~127-8 = 119 + # (since E4M3 max is 448, log2(1/448) ≈ -8.8, ceil = -8, biased = 127 + (-8) = 119) + test_tensor[0, :32] = 1.0 + + # Second block (row 0, elements 32-63): max abs = 448.0, should give exponent = 127 + # (since 448/448 = 1, log2(1) = 0, biased = 127) + test_tensor[0, 32:64] = 448.0 + + # Third block (row 1, elements 0-31): max abs = 2.0 + test_tensor[1, :32] = 2.0 + + # Fourth block (row 1, elements 32-63): max abs = 0.5 + test_tensor[1, 32:64] = 0.5 + + # Quantize + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) + + # Verify all scales are valid uint8 values + assert e8m0_scale.dtype == torch.uint8 + assert e8m0_scale.shape == (2, 2) + + # Verify dequantization works + dequant = qtensor.dequantize( + dtype=torch.float32, + scale=e8m0_scale, + block_sizes={-1: block_size}, + ) + + # Check that the dequantized max values per block are close to original + assert torch.allclose(dequant[0, :32].max(), torch.tensor(1.0, device=device), rtol=0.1) + assert torch.allclose(dequant[0, 32:64].max(), torch.tensor(448.0, device=device), rtol=0.1) + assert torch.allclose(dequant[1, :32].max(), torch.tensor(2.0, device=device), rtol=0.1) + assert torch.allclose(dequant[1, 32:64].max(), torch.tensor(0.5, device=device), rtol=0.1) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("block_size", [32]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "test_input", + [ + # FP8 E4M3 boundary test values (max is 448, various powers of 2) + torch.tensor( + [ + [ + 1.0, + 2.0, + 4.0, + 8.0, + 16.0, + 32.0, + 64.0, + 128.0, + 256.0, + 448.0, + 0.5, + 0.25, + 0.125, + 0.0625, + 0.03125, + 0.015625, + -1.0, + -2.0, + -4.0, + -8.0, + -16.0, + -32.0, + -64.0, + -128.0, + -256.0, + -448.0, + -0.5, + -0.25, + -0.125, + -0.0625, + -0.03125, + -0.015625, + ] + ] + ), + # Mix of positive and negative values near E4M3 boundaries + 
torch.tensor( + [ + [ + 448.0, + 416.0, + 384.0, + 352.0, + 320.0, + 288.0, + 256.0, + 224.0, + 192.0, + 160.0, + 128.0, + 96.0, + 64.0, + 48.0, + 32.0, + 24.0, + -448.0, + -416.0, + -384.0, + -352.0, + -320.0, + -288.0, + -256.0, + -224.0, + -192.0, + -160.0, + -128.0, + -96.0, + -64.0, + -48.0, + -32.0, + -24.0, + ] + ] + ), + ], + ) + def test_mxfp8_quantize_boundary_values(self, test_input, device, block_size, input_dtype): + """Test MXFP8 quantization with E4M3 boundary values.""" + x = test_input.to(input_dtype).to(device) + qtensor, e8m0_scale = MXFP8QTensor.quantize(x, block_size) + + # Verify scale is uint8 (E8M0 format) + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + + dequant = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + block_sizes={-1: block_size}, + ) + + # FP8 E4M3 has limited precision, allow reasonable tolerance + assert torch.allclose(dequant, x, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: max diff = {(dequant - x).abs().max()}" + ) + + @pytest.mark.parametrize( + "input_shape", + [(1600, 1600)], + ) + def test_mxfp8_quantize_gpu_mem(self, input_shape): + """Test MXFP8 GPU memory usage during quantization.""" + + def _get_gpu_mem_used(): + device = torch.device("cuda:0") + free, total = torch.cuda.mem_get_info(device) + return total - free + + block_size = 32 + + # Warmup + test_input = torch.rand((32, 32), dtype=torch.float32, device="cuda") + MXFP8QTensor.quantize(test_input, block_size) + + test_input = torch.rand(input_shape, dtype=torch.float32, device="cuda") + torch.cuda.empty_cache() + + input_size = test_input.element_size() * test_input.numel() + before_quantize = _get_gpu_mem_used() + MXFP8QTensor.quantize(test_input, block_size) + after_quantize = _get_gpu_mem_used() + + # Memory increase should be reasonable (less than 3x input size) + # MXFP8 stores FP8 data (1 byte) + uint8 scales, so should be efficient + assert (after_quantize - before_quantize) < input_size * 3, ( + f"Memory increase too large: {after_quantize - before_quantize} bytes " + f"for input size {input_size} bytes" + ) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("block_size", [32]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "input_shape", + [ + (128, 65), # last dim not divisible by block_size + (256, 100), # last dim not divisible by block_size + (64, 33), # odd number, not divisible by block_size + ], + ) + def test_mxfp8_quantize_with_padding(self, device, block_size, input_dtype, input_shape): + """Test MXFP8 quantization with inputs requiring padding.""" + test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + # Quantize using MXFP8QTensor + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) + + # Verify scale is uint8 (E8M0 format) + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + + # Verify quantized data preserves original shape (not padded shape) + assert qtensor._quantized_data.shape == input_shape, ( + f"Expected quantized data shape {input_shape}, got {qtensor._quantized_data.shape}" + ) + + # Dequantize + dequant_tensor = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + block_sizes={-1: block_size}, + ) + + # Verify dequantized tensor shape matches original + assert dequant_tensor.shape == input_shape, ( + f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}" + ) + + # 
Verify dequantized tensor is close to original + assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: " + f"max diff = {(dequant_tensor - test_tensor).abs().max()}" + ) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("block_size", [32]) + @pytest.mark.parametrize( + "input_shape", + [(128, 64), (256, 128), (512, 256)], + ) + def test_mxfp8_get_weights_scaling_factor(self, device, block_size, input_shape): + """Test MXFP8 get_weights_scaling_factor returns correct E8M0 scales.""" + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + # Get scaling factor + e8m0_scale = MXFP8QTensor.get_weights_scaling_factor(weight, block_size) + + # Verify dtype and shape + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + expected_shape = (input_shape[0], input_shape[1] // block_size) + assert e8m0_scale.shape == expected_shape, ( + f"Expected scale shape {expected_shape}, got {e8m0_scale.shape}" + ) + + # Verify E8M0 values are in valid range [0, 254] (biased exponent = unbiased + 127) + # The code clamps unbiased exponent to [-127, 127], giving biased range [0, 254] + # Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights + assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)" + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_mxfp8_quantize_3d_tensor(self, device, input_dtype): + """Test MXFP8 quantization with 3D tensors (e.g., MoE experts).""" + block_size = 32 + # Shape: (num_experts, out_dim, in_dim) + input_shape = (4, 64, 128) + test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + # Quantize using MXFP8QTensor + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) + + # Verify scale shape: should be (num_experts, out_dim, in_dim // block_size) + expected_scale_shape = (4, 64, 128 // block_size) + assert e8m0_scale.shape == expected_scale_shape, ( + f"Expected scale shape {expected_scale_shape}, got {e8m0_scale.shape}" + ) + + # Verify quantized data preserves original shape + assert qtensor._quantized_data.shape == input_shape, ( + f"Expected quantized data shape {input_shape}, got {qtensor._quantized_data.shape}" + ) + + # Dequantize + dequant_tensor = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + block_sizes={-1: block_size}, + ) + + # Verify dequantized tensor is close to original + assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: " + f"max diff = {(dequant_tensor - test_tensor).abs().max()}" + ) From 235680eb194c342f6467587bab8928415b16b963 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 6 Jan 2026 15:13:10 +0200 Subject: [PATCH 02/11] Refactor MXFP8QTensor and remove irrelevant logic and tests Signed-off-by: Daniel Serebrenik --- modelopt/torch/export/quant_utils.py | 53 +---- modelopt/torch/export/unified_export_hf.py | 38 +--- .../nn/modules/tensor_quantizer.py | 6 +- .../quantization/qtensor/mxfp8_tensor.py | 193 ++++++++++-------- .../torch/quantization/test_qtensor_cuda.py | 145 ++++--------- 5 files changed, 164 insertions(+), 271 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 21459b69a..eb4910570 100755 --- a/modelopt/torch/export/quant_utils.py +++ 
b/modelopt/torch/export/quant_utils.py @@ -16,7 +16,6 @@ """Utils for quantization including scaling factors adjustments.""" import logging -import math from collections.abc import Generator from types import SimpleNamespace from typing import Any @@ -757,56 +756,6 @@ def pack_int4_in_uint8(weight, weights_scaling_factor): return packed_byte.T.contiguous().view(torch.uint8) -def _quantize_weight_mxfp8( - weight: torch.Tensor, - weights_scaling_factor: torch.Tensor, -) -> torch.Tensor: - """Quantize weight tensor using MXFP8 format. - - MXFP8 uses dynamic block quantization with FP8 (E4M3) along dimension -1 only (1D blocking). - Scales are E8M0 format (power-of-2 only), stored as biased uint8 exponents. - """ - assert weights_scaling_factor is not None, ( - "weights_scaling_factor must be provided for MXFP8 quantization." - ) - - # MXFP8 block size is 32 - block_size = 32 - maxbound = torch.finfo(torch.float8_e4m3fn).max # 448.0 - - out_dim, in_dim = weight.shape[-2], weight.shape[-1] - expected_shape = (out_dim, in_dim // block_size) - - # Reshape scaling factor if needed (same number of elements but wrong shape) - if weights_scaling_factor.shape != expected_shape: - if weights_scaling_factor.numel() == math.prod(expected_shape): - weights_scaling_factor = weights_scaling_factor.reshape(expected_shape) - - # Handle E8M0 uint8 scale format (biased exponent) - if weights_scaling_factor.dtype == torch.uint8: - # E8M0 format: descale = 2^(exponent - 127), scale = 2^(127 - exponent) - scale_factor = torch.exp2(127 - weights_scaling_factor.float()) - else: - # Legacy float32 scale format: scale = amax / maxbound - # Convert to E8M0: exponent = ceil(log2(scale)) - e8m0_exponent = torch.ceil(torch.log2(weights_scaling_factor.clamp(min=2**-127))) - e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) - scale_factor = torch.exp2(-e8m0_exponent) - - # Reshape weight to [out_dim, num_blocks, block_size] - num_blocks = in_dim // block_size - weight_reshaped = weight.view(out_dim, num_blocks, block_size) - - # Apply scale and quantize to FP8 E4M3 - scale_factor_expanded = scale_factor.unsqueeze(-1) - scaled_weight = weight_reshaped * scale_factor_expanded - scaled_weight = torch.clamp(scaled_weight, min=-maxbound, max=maxbound) - quantized_weight = scaled_weight.to(torch.float8_e4m3fn) - - # Reshape back to original 2D shape - return quantized_weight.view(out_dim, in_dim) - - def to_quantized_weight( weight: torch.Tensor, weights_scaling_factor: torch.Tensor, @@ -843,7 +792,7 @@ def to_quantized_weight( return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8) if quantization == QUANTIZATION_MXFP8: - return _quantize_weight_mxfp8(weight, weights_scaling_factor) + return MXFP8QTensor.quantize_with_e8m0_scale(weight, weights_scaling_factor) if quantization == QUANTIZATION_FP8_PB_WO: return FP8QTensor.quantize( diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 96cb7b16a..6296bb816 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -32,7 +32,7 @@ from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer -from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names from .convert_hf_config import 
convert_hf_quant_config_format @@ -300,39 +300,13 @@ def _export_quantized_weight( del weight_quantizer._scale elif quantization_format == QUANTIZATION_MXFP8: # MXFP8 uses dynamic block quantization with E8M0 scales (uint8) + weight = getattr(sub_module, weight_name) + e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer( + weight, weight_quantizer + ) + sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale) if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: - # If _scale is already uint8 (from MXFP8QTensor.quantize), keep it as is - if weight_quantizer._scale.dtype == torch.uint8: - sub_module.register_buffer( - quantizer_attrs.weight_scale, - weight_quantizer._scale, - ) - else: - # Legacy path: convert float32 scale to E8M0 uint8 - # scale = amax / E4M3_max, so exponent = ceil(log2(scale)) - scale = weight_quantizer._scale.to(torch.float32) - e8m0_exponent = torch.ceil(torch.log2(scale.clamp(min=2**-127))) - e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) - e8m0_scale = (e8m0_exponent + 127).to(torch.uint8) - sub_module.register_buffer( - quantizer_attrs.weight_scale, - e8m0_scale, - ) del weight_quantizer._scale - else: - # Compute E8M0 scaling factor from weight tensor - weight = getattr(sub_module, weight_name) - block_size = weight_quantizer.block_sizes[-1] - out_dim, in_dim = weight.shape[-2], weight.shape[-1] - num_blocks = in_dim // block_size - weight_reshaped = weight.view(out_dim, num_blocks, block_size) - amax = weight_reshaped.float().abs().max(dim=-1)[0] - maxbound = torch.finfo(torch.float8_e4m3fn).max - descale = amax / maxbound - e8m0_exponent = torch.ceil(torch.log2(descale.clamp(min=2**-127))) - e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) - e8m0_scale = (e8m0_exponent + 127).to(torch.uint8) - sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale) else: sub_module.register_buffer( quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index bc09c4ddf..66bafd47d 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -695,7 +695,11 @@ def _real_quantize(self, inputs): buffer_to_register["_scale"] = scales elif self._num_bits == (4, 3): # MXFP8 - outputs, scales = MXFP8QTensor.quantize(inputs, self._block_sizes[-1]) + assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, ( + f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, " + f"got {self._block_sizes[-1]}" + ) + outputs, scales = MXFP8QTensor.quantize(inputs) buffer_to_register["_scale"] = scales else: raise ValueError( diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py index 23e50220e..fb6651421 100644 --- a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -36,6 +36,7 @@ class MXFP8QTensor(BaseQuantizedTensor): E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 BLOCK_SIZE = 32 + SCALE_DTYPE = torch.uint8 # E8M0 format stores biased exponent as uint8 @classmethod def _compute_e8m0_exponent(cls, amax: torch.Tensor) -> torch.Tensor: @@ -65,41 +66,32 @@ def _compute_e8m0_exponent(cls, amax: torch.Tensor) -> torch.Tensor: return torch.clamp(e8m0_exponent, min=-127, max=127) @classmethod - def get_weights_scaling_factor( - cls, - weight: 
torch.Tensor, - block_size: int | None = None, - ) -> torch.Tensor: + def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor: """Returns E8M0 scale (uint8 biased exponent) for weight tensor. Args: - weight: The weight tensor to compute scale for. Must be 2D. - block_size: The block size for quantization. Defaults to 32. + weight: The weight tensor to compute scale for. Must be at least 2D. Returns: - torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // block_size]. + torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32]. """ - if block_size is None: - block_size = cls.BLOCK_SIZE - - assert block_size > 0, f"block_size must be positive, got {block_size}" assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" - out_dim, in_dim = weight.shape[-2], weight.shape[-1] + in_dim = weight.shape[-1] - assert in_dim % block_size == 0, ( - f"Weight inner dimension ({in_dim}) must be divisible by block_size ({block_size})" + assert in_dim % cls.BLOCK_SIZE == 0, ( + f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" ) - # Reshape to [out_dim, num_blocks, block_size] - weight_reshaped = weight.view(out_dim, in_dim // block_size, block_size) + # Reshape to [..., num_blocks, block_size] + weight_reshaped = weight.view(*weight.shape[:-1], in_dim // cls.BLOCK_SIZE, cls.BLOCK_SIZE) # Compute amax per block amax = weight_reshaped.float().abs().max(dim=-1)[0] # Compute E8M0 exponent and convert to biased uint8 (bias = 127) e8m0_exponent = cls._compute_e8m0_exponent(amax) - return (e8m0_exponent + 127).to(torch.uint8) + return (e8m0_exponent + 127).to(cls.SCALE_DTYPE) @classmethod def get_weights_scaling_factor_from_quantizer( @@ -117,101 +109,137 @@ def get_weights_scaling_factor_from_quantizer( weight_quantizer: The weight quantizer with block_sizes and optional _scale. Returns: - torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // block_size]. + torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // 32]. 
""" assert hasattr(weight_quantizer, "block_sizes"), ( "weight_quantizer must have 'block_sizes' attribute" ) + assert weight_quantizer.block_sizes[-1] == cls.BLOCK_SIZE, ( + f"MXFP8 requires block size {cls.BLOCK_SIZE}, got {weight_quantizer.block_sizes[-1]}" + ) assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" - block_size = weight_quantizer.block_sizes[-1] out_dim, in_dim = weight.shape[-2], weight.shape[-1] - expected_shape = (out_dim, in_dim // block_size) + expected_shape = (out_dim, in_dim // cls.BLOCK_SIZE) if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: scale = weight_quantizer._scale - # If already uint8 E8M0 format, return as-is (with shape correction if needed) - if scale.dtype == torch.uint8: - if ( - scale.shape != expected_shape - and scale.numel() == expected_shape[0] * expected_shape[1] - ): - scale = scale.reshape(expected_shape) - return scale - - # Legacy float32 scale - convert to E8M0 uint8 - if scale.shape != expected_shape: - if scale.numel() == expected_shape[0] * expected_shape[1]: - scale = scale.reshape(expected_shape) - else: - # Shape mismatch, recompute from weight - return cls.get_weights_scaling_factor(weight, block_size) - - # Convert float32 scale to E8M0 uint8 - e8m0_exponent = torch.ceil(torch.log2(scale.clamp(min=2**-127))) - e8m0_exponent = torch.clamp(e8m0_exponent, min=-127, max=127) - return (e8m0_exponent + 127).to(torch.uint8) + assert scale.dtype == cls.SCALE_DTYPE, ( + f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}" + ) + + # Reshape if needed (same number of elements but wrong shape) + if ( + scale.shape != expected_shape + and scale.numel() == expected_shape[0] * expected_shape[1] + ): + scale = scale.reshape(expected_shape) + return scale # No scale in quantizer, compute from weight - return cls.get_weights_scaling_factor(weight, block_size) + return cls.get_weights_scaling_factor(weight) + + @classmethod + def quantize_with_e8m0_scale( + cls, + weight: torch.Tensor, + e8m0_scale: torch.Tensor, + ) -> torch.Tensor: + """Quantize weight tensor using a pre-computed E8M0 scale. + + This method is useful for export paths where the scale has already been computed. + + Args: + weight: The weight tensor to quantize. Must be at least 2D. + e8m0_scale: E8M0 scale as uint8 biased exponent (bias = 127). + Shape should be [..., out_dim, in_dim // 32]. + + Returns: + torch.Tensor: Quantized weight as float8_e4m3fn with same shape as input. 
+ """ + assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" + assert e8m0_scale.dtype == cls.SCALE_DTYPE, ( + f"e8m0_scale must be {cls.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}" + ) + + in_dim = weight.shape[-1] + num_blocks = in_dim // cls.BLOCK_SIZE + + assert in_dim % cls.BLOCK_SIZE == 0, ( + f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" + ) + + # Reshape scale if needed (same number of elements but wrong shape) + expected_shape = (*weight.shape[:-1], num_blocks) + if e8m0_scale.shape != expected_shape: + if e8m0_scale.numel() == weight.numel() // cls.BLOCK_SIZE: + e8m0_scale = e8m0_scale.reshape(expected_shape) + + # Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent) + scale_factor = torch.exp2(127 - e8m0_scale.float()) + + # NOTE: vLLM/flashinfer may require this behavior: + # scale_factor = torch.where( + # e8m0_scale == 0, + # 1.0, + # torch.exp2(127 - e8m0_scale.float()) + # ) + + # Reshape weight to [..., out_dim, num_blocks, block_size] + weight_reshaped = weight.view(*weight.shape[:-1], num_blocks, cls.BLOCK_SIZE) + + # Apply scale and quantize to FP8 E4M3 + scale_factor_expanded = scale_factor.unsqueeze(-1) + scaled_weight = weight_reshaped * scale_factor_expanded + scaled_weight = torch.clamp(scaled_weight, min=-cls.E4M3_MAX, max=cls.E4M3_MAX) + quantized_weight = scaled_weight.to(torch.float8_e4m3fn) + + # Reshape back to original shape + return quantized_weight.view(weight.shape) @classmethod - def quantize(cls, input: torch.Tensor, block_size: int | None = None) -> tuple: + def quantize(cls, input: torch.Tensor) -> tuple: """Convert a tensor to MXFP8 quantized format. Args: input (torch.Tensor): The input tensor to be quantized. - block_size (int | None): The block size for quantization. Defaults to 32. Returns: tuple: (MXFP8QTensor, e8m0_scale) where e8m0_scale is uint8 biased exponent. 
""" - if block_size is None: - block_size = cls.BLOCK_SIZE - - assert block_size > 0, f"block_size must be positive, got {block_size}" assert input.numel() > 0, "Input tensor must not be empty" assert input.dim() >= 1, f"Input must have at least 1 dimension, got {input.dim()}D" + assert input.is_floating_point(), f"Input must be floating point, got {input.dtype}" original_shape = input.shape original_dtype = input.dtype # Pad last dimension if not divisible by block_size last_dim = original_shape[-1] - if last_dim % block_size != 0: - pad_size = block_size - (last_dim % block_size) + if last_dim % cls.BLOCK_SIZE != 0: + pad_size = cls.BLOCK_SIZE - (last_dim % cls.BLOCK_SIZE) input = torch.nn.functional.pad(input, (0, pad_size), mode="constant", value=0) # Flatten to [num_blocks, block_size] for block-wise quantization - input_flat = input.view(-1, block_size) + input_flat = input.view(-1, cls.BLOCK_SIZE) - # Compute amax per block + # Compute amax per block and E8M0 scale input_amax = input_flat.float().abs().max(dim=-1, keepdim=True).values - - # Compute E8M0 exponent and scale factor e8m0_exponent = cls._compute_e8m0_exponent(input_amax) - scale_factor = torch.exp2(-e8m0_exponent) + e8m0_scale = (e8m0_exponent + 127).to(cls.SCALE_DTYPE) - # Apply scale and quantize to FP8 E4M3 - scaled_input = input_flat * scale_factor - - # Clamp to E4M3 range [-448, 448] and convert - scaled_input = torch.clamp(scaled_input, min=-cls.E4M3_MAX, max=cls.E4M3_MAX) - quantized_data = scaled_input.to(torch.float8_e4m3fn) - - # Reshape: account for padded shape, then crop back to original + # Reshape scale to match padded input shape for quantize_with_e8m0_scale padded_shape = list(original_shape) padded_shape[-1] = input.shape[-1] - quantized_data = quantized_data.view(padded_shape) - quantized_data = quantized_data[..., :last_dim] + scale_shape = [*original_shape[:-1], input.shape[-1] // cls.BLOCK_SIZE] + e8m0_scale = e8m0_scale.view(scale_shape) - # Convert exponent to biased uint8 (bias = 127) - e8m0_scale = (e8m0_exponent + 127).to(torch.uint8) + # Use quantize_with_e8m0_scale for the actual quantization (single source of truth) + quantized_data = cls.quantize_with_e8m0_scale(input.view(padded_shape), e8m0_scale) - # Reshape scale to preserve leading dimensions: (*original_shape[:-1], padded_last_dim // block_size) - scale_shape = [*original_shape[:-1], input.shape[-1] // block_size] - e8m0_scale = e8m0_scale.view(scale_shape) + # Crop back to original shape + quantized_data = quantized_data[..., :last_dim] return cls(original_shape, original_dtype, quantized_data), e8m0_scale @@ -220,18 +248,17 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: Args: dtype (torch.dtype | None): Target dtype for dequantization. Defaults to original dtype. - **kwargs: Must contain 'scale' (E8M0 biased uint8) and 'block_sizes'. + **kwargs: Must contain 'scale' (E8M0 biased uint8). Returns: torch.Tensor: Dequantized tensor in the target dtype. 
""" assert "scale" in kwargs, "dequantize requires 'scale' in kwargs" - assert "block_sizes" in kwargs, "dequantize requires 'block_sizes' in kwargs" e8m0_scale = kwargs["scale"] - block_size = kwargs["block_sizes"][-1] - - assert block_size > 0, f"block_size must be positive, got {block_size}" + assert e8m0_scale.dtype == self.SCALE_DTYPE, ( + f"e8m0_scale must be {self.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}" + ) if dtype is None: dtype = self.metadata["dtype"] @@ -241,15 +268,19 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: quantized_data = self._quantized_data # Validate scale shape matches expected number of blocks - padded_last_dim = last_dim + (block_size - last_dim % block_size) % block_size - expected_num_blocks = (quantized_data.numel() // last_dim) * (padded_last_dim // block_size) + padded_last_dim = ( + last_dim + (self.BLOCK_SIZE - last_dim % self.BLOCK_SIZE) % self.BLOCK_SIZE + ) + expected_num_blocks = (quantized_data.numel() // last_dim) * ( + padded_last_dim // self.BLOCK_SIZE + ) assert e8m0_scale.numel() == expected_num_blocks, ( f"Scale has {e8m0_scale.numel()} elements but expected {expected_num_blocks} blocks" ) # Pad last dimension if not divisible by block_size - if last_dim % block_size != 0: - pad_size = block_size - (last_dim % block_size) + if last_dim % self.BLOCK_SIZE != 0: + pad_size = self.BLOCK_SIZE - (last_dim % self.BLOCK_SIZE) quantized_data = torch.nn.functional.pad( quantized_data.float(), (0, pad_size), mode="constant", value=0 ) @@ -257,7 +288,7 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: quantized_data = quantized_data.float() # Flatten to [num_blocks, block_size] for block-wise dequantization - quantized_flat = quantized_data.view(-1, block_size) + quantized_flat = quantized_data.view(-1, self.BLOCK_SIZE) # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127) descale = torch.exp2(e8m0_scale.float() - 127) diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index bdbfc6565..ba401ea15 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -15,6 +15,8 @@ """Unit tests for quantized tensors.""" +import math + import pytest import torch from _test_utils.torch.misc import set_seed @@ -613,41 +615,61 @@ def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, b @pytest.mark.parametrize("device", ["cuda", "cpu"]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) - @pytest.mark.parametrize("block_size", [32]) @pytest.mark.parametrize( "input_shape", - [(128, 128), (256, 64), (512, 512)], + [ + (128, 128), + (256, 64), + (512, 512), + # 3D shapes (MoE): (num_experts, out_dim, in_dim) + (4, 64, 128), + (1, 64, 128), # single expert edge case + (32, 256, 512), # large-scale MoE + # Shapes requiring padding (last dim not divisible by block size 32) + (8, 128, 65), # odd in_dim + (128, 65), + (256, 100), + (64, 33), + ], ) - def test_mxfp8_quantize_dequantize(self, device, input_dtype, block_size, input_shape): + def test_mxfp8_quantize_dequantize(self, device, input_dtype, input_shape): """Test MXFP8 quantization and dequantization produces correct E8M0 scales.""" # Create test tensor test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) # Quantize using MXFP8QTensor - qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) + qtensor, 
e8m0_scale = MXFP8QTensor.quantize(test_tensor) # Verify scale is uint8 (E8M0 format) assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" - # Verify scale shape is [out_dim, in_dim // block_size] - expected_scale_shape = (input_shape[0], input_shape[1] // block_size) + # Verify scale shape: last dim is ceil(in_dim / 32), other dims preserved + expected_scale_shape = ( + *input_shape[:-1], + math.ceil(input_shape[-1] / MXFP8QTensor.BLOCK_SIZE), + ) assert e8m0_scale.shape == expected_scale_shape, ( f"Expected scale shape {expected_scale_shape}, got {e8m0_scale.shape}" ) - # Verify quantized data is FP8 E4M3 + # Verify quantized data is FP8 E4M3 and preserves original shape assert qtensor._quantized_data.dtype == torch.float8_e4m3fn, ( f"Expected float8_e4m3fn, got {qtensor._quantized_data.dtype}" ) + assert qtensor._quantized_data.shape == input_shape, ( + f"Expected quantized data shape {input_shape}, got {qtensor._quantized_data.shape}" + ) # Dequantize dequant_tensor = qtensor.dequantize( dtype=input_dtype, scale=e8m0_scale, - block_sizes={-1: block_size}, ) - # Verify dequantized tensor is close to original + # Verify dequantized tensor shape and values match original + assert dequant_tensor.shape == input_shape, ( + f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}" + ) assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( f"Dequantized tensor differs from original: " f"max diff = {(dequant_tensor - test_tensor).abs().max()}" @@ -657,8 +679,7 @@ def test_mxfp8_quantize_dequantize(self, device, input_dtype, block_size, input_ def test_mxfp8_e8m0_scale_values(self, device): """Test that MXFP8 produces correct E8M0 scale values (power-of-2 only).""" # Create a tensor with known amax values per block - # Block size is 32, so create a 2x64 tensor (2 rows, 2 blocks per row) - block_size = 32 + # MXFP8 block size is 32, so create a 2x64 tensor (2 rows, 2 blocks per row) test_tensor = torch.zeros((2, 64), dtype=torch.float32, device=device) # First block (row 0, elements 0-31): max abs = 1.0, should give exponent ~127-8 = 119 @@ -676,7 +697,7 @@ def test_mxfp8_e8m0_scale_values(self, device): test_tensor[1, 32:64] = 0.5 # Quantize - qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor) # Verify all scales are valid uint8 values assert e8m0_scale.dtype == torch.uint8 @@ -686,7 +707,6 @@ def test_mxfp8_e8m0_scale_values(self, device): dequant = qtensor.dequantize( dtype=torch.float32, scale=e8m0_scale, - block_sizes={-1: block_size}, ) # Check that the dequantized max values per block are close to original @@ -696,7 +716,6 @@ def test_mxfp8_e8m0_scale_values(self, device): assert torch.allclose(dequant[1, 32:64].max(), torch.tensor(0.5, device=device), rtol=0.1) @pytest.mark.parametrize("device", ["cuda"]) - @pytest.mark.parametrize("block_size", [32]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize( "test_input", @@ -781,10 +800,10 @@ def test_mxfp8_e8m0_scale_values(self, device): ), ], ) - def test_mxfp8_quantize_boundary_values(self, test_input, device, block_size, input_dtype): + def test_mxfp8_quantize_boundary_values(self, test_input, device, input_dtype): """Test MXFP8 quantization with E4M3 boundary values.""" x = test_input.to(input_dtype).to(device) - qtensor, e8m0_scale = MXFP8QTensor.quantize(x, block_size) + qtensor, e8m0_scale = MXFP8QTensor.quantize(x) # Verify 
scale is uint8 (E8M0 format) assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" @@ -792,7 +811,6 @@ def test_mxfp8_quantize_boundary_values(self, test_input, device, block_size, in dequant = qtensor.dequantize( dtype=input_dtype, scale=e8m0_scale, - block_sizes={-1: block_size}, ) # FP8 E4M3 has limited precision, allow reasonable tolerance @@ -812,18 +830,16 @@ def _get_gpu_mem_used(): free, total = torch.cuda.mem_get_info(device) return total - free - block_size = 32 - # Warmup test_input = torch.rand((32, 32), dtype=torch.float32, device="cuda") - MXFP8QTensor.quantize(test_input, block_size) + MXFP8QTensor.quantize(test_input) test_input = torch.rand(input_shape, dtype=torch.float32, device="cuda") torch.cuda.empty_cache() input_size = test_input.element_size() * test_input.numel() before_quantize = _get_gpu_mem_used() - MXFP8QTensor.quantize(test_input, block_size) + MXFP8QTensor.quantize(test_input) after_quantize = _get_gpu_mem_used() # Memory increase should be reasonable (less than 3x input size) @@ -834,65 +850,20 @@ def _get_gpu_mem_used(): ) @pytest.mark.parametrize("device", ["cuda"]) - @pytest.mark.parametrize("block_size", [32]) - @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) - @pytest.mark.parametrize( - "input_shape", - [ - (128, 65), # last dim not divisible by block_size - (256, 100), # last dim not divisible by block_size - (64, 33), # odd number, not divisible by block_size - ], - ) - def test_mxfp8_quantize_with_padding(self, device, block_size, input_dtype, input_shape): - """Test MXFP8 quantization with inputs requiring padding.""" - test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) - - # Quantize using MXFP8QTensor - qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) - - # Verify scale is uint8 (E8M0 format) - assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" - - # Verify quantized data preserves original shape (not padded shape) - assert qtensor._quantized_data.shape == input_shape, ( - f"Expected quantized data shape {input_shape}, got {qtensor._quantized_data.shape}" - ) - - # Dequantize - dequant_tensor = qtensor.dequantize( - dtype=input_dtype, - scale=e8m0_scale, - block_sizes={-1: block_size}, - ) - - # Verify dequantized tensor shape matches original - assert dequant_tensor.shape == input_shape, ( - f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}" - ) - - # Verify dequantized tensor is close to original - assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( - f"Dequantized tensor differs from original: " - f"max diff = {(dequant_tensor - test_tensor).abs().max()}" - ) - - @pytest.mark.parametrize("device", ["cuda"]) - @pytest.mark.parametrize("block_size", [32]) @pytest.mark.parametrize( "input_shape", [(128, 64), (256, 128), (512, 256)], ) - def test_mxfp8_get_weights_scaling_factor(self, device, block_size, input_shape): + def test_mxfp8_get_weights_scaling_factor(self, device, input_shape): """Test MXFP8 get_weights_scaling_factor returns correct E8M0 scales.""" weight = torch.randn(input_shape, dtype=torch.float32, device=device) # Get scaling factor - e8m0_scale = MXFP8QTensor.get_weights_scaling_factor(weight, block_size) + e8m0_scale = MXFP8QTensor.get_weights_scaling_factor(weight) # Verify dtype and shape assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" - expected_shape = (input_shape[0], input_shape[1] // 
block_size) + expected_shape = (input_shape[0], input_shape[1] // MXFP8QTensor.BLOCK_SIZE) assert e8m0_scale.shape == expected_shape, ( f"Expected scale shape {expected_shape}, got {e8m0_scale.shape}" ) @@ -901,39 +872,3 @@ def test_mxfp8_get_weights_scaling_factor(self, device, block_size, input_shape) # The code clamps unbiased exponent to [-127, 127], giving biased range [0, 254] # Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)" - - @pytest.mark.parametrize("device", ["cuda"]) - @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) - def test_mxfp8_quantize_3d_tensor(self, device, input_dtype): - """Test MXFP8 quantization with 3D tensors (e.g., MoE experts).""" - block_size = 32 - # Shape: (num_experts, out_dim, in_dim) - input_shape = (4, 64, 128) - test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) - - # Quantize using MXFP8QTensor - qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor, block_size) - - # Verify scale shape: should be (num_experts, out_dim, in_dim // block_size) - expected_scale_shape = (4, 64, 128 // block_size) - assert e8m0_scale.shape == expected_scale_shape, ( - f"Expected scale shape {expected_scale_shape}, got {e8m0_scale.shape}" - ) - - # Verify quantized data preserves original shape - assert qtensor._quantized_data.shape == input_shape, ( - f"Expected quantized data shape {input_shape}, got {qtensor._quantized_data.shape}" - ) - - # Dequantize - dequant_tensor = qtensor.dequantize( - dtype=input_dtype, - scale=e8m0_scale, - block_sizes={-1: block_size}, - ) - - # Verify dequantized tensor is close to original - assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( - f"Dequantized tensor differs from original: " - f"max diff = {(dequant_tensor - test_tensor).abs().max()}" - ) From 9ddc71ef5cd8f974e7531e830642a143c3ed8601 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 6 Jan 2026 15:14:00 +0200 Subject: [PATCH 03/11] Add mxfp8 to test_llm_ptq and huggingface_example Signed-off-by: Daniel Serebrenik --- examples/llm_ptq/scripts/huggingface_example.sh | 4 ++-- tests/examples/llm_ptq/test_llm_ptq.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 043b690e5..ec415b2f7 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -53,9 +53,9 @@ esac IFS="," for qformat in $QFORMAT; do case $qformat in - fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;; + fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | mxfp8) ;; *) - echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2 + echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, mxfp8]" >&2 exit 1 ;; esac diff --git a/tests/examples/llm_ptq/test_llm_ptq.py b/tests/examples/llm_ptq/test_llm_ptq.py index 6ba23cc04..4fc39f5ec 100644 --- a/tests/examples/llm_ptq/test_llm_ptq.py +++ b/tests/examples/llm_ptq/test_llm_ptq.py @@ 
-114,6 +114,7 @@ def test_ptq_whisper(self, command): # sm89 PTQCommand(quant="fp8", min_sm=89), PTQCommand(quant="fp8", kv_cache_quant="none", min_sm=89), # sm100 + PTQCommand(quant="mxfp8", min_sm=100), PTQCommand(quant="nvfp4", min_sm=100), # # multi_gpu From beb3ba6cb5b4b9d0c656f831b5b55b087cdb11f5 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 6 Jan 2026 20:42:49 +0200 Subject: [PATCH 04/11] Use existing utils functions in MXFP8QTensor Signed-off-by: Daniel Serebrenik --- modelopt/torch/export/quant_utils.py | 2 +- .../quantization/qtensor/mxfp8_tensor.py | 89 ++++++------------- 2 files changed, 26 insertions(+), 65 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index eb4910570..87b1018c3 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -792,7 +792,7 @@ def to_quantized_weight( return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8) if quantization == QUANTIZATION_MXFP8: - return MXFP8QTensor.quantize_with_e8m0_scale(weight, weights_scaling_factor) + return MXFP8QTensor.quantize_with_scale(weight, weights_scaling_factor) if quantization == QUANTIZATION_FP8_PB_WO: return FP8QTensor.quantize( diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py index fb6651421..fa9345ab1 100644 --- a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -18,6 +18,7 @@ import torch from ..qtensor.base_qtensor import BaseQuantizedTensor +from ..utils import reduce_block_amax, reduce_block_padding __all__ = ["MXFP8QTensor"] @@ -49,7 +50,7 @@ def _compute_e8m0_exponent(cls, amax: torch.Tensor) -> torch.Tensor: torch.Tensor: Float tensor of E8M0 exponents (unbiased, range [-127, 127]). 
""" # Compute E8M0 scale: scale = 2^ceil(log2(amax / E4M3_max)) - descale = amax / cls.E4M3_MAX + descale = amax.float() / cls.E4M3_MAX # Handle zero/inf/nan cases min_value = torch.tensor(-127.0, device=descale.device) @@ -59,10 +60,9 @@ def _compute_e8m0_exponent(cls, amax: torch.Tensor) -> torch.Tensor: min_value, ) - # Ceil to get power-of-2 scale e8m0_exponent = torch.ceil(log2_descale) - # Clamp exponent to valid E8M0 range [-127, 127] + # Clamp exponent to valid E8M0 range return torch.clamp(e8m0_exponent, min=-127, max=127) @classmethod @@ -83,11 +83,8 @@ def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor: f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" ) - # Reshape to [..., num_blocks, block_size] - weight_reshaped = weight.view(*weight.shape[:-1], in_dim // cls.BLOCK_SIZE, cls.BLOCK_SIZE) - - # Compute amax per block - amax = weight_reshaped.float().abs().max(dim=-1)[0] + # Compute amax per block (reduce_block_amax handles reshaping internally) + amax = reduce_block_amax(weight, block_sizes={-1: cls.BLOCK_SIZE}) # Compute E8M0 exponent and convert to biased uint8 (bias = 127) e8m0_exponent = cls._compute_e8m0_exponent(amax) @@ -141,7 +138,7 @@ def get_weights_scaling_factor_from_quantizer( return cls.get_weights_scaling_factor(weight) @classmethod - def quantize_with_e8m0_scale( + def quantize_with_scale( cls, weight: torch.Tensor, e8m0_scale: torch.Tensor, @@ -186,16 +183,12 @@ def quantize_with_e8m0_scale( # torch.exp2(127 - e8m0_scale.float()) # ) - # Reshape weight to [..., out_dim, num_blocks, block_size] weight_reshaped = weight.view(*weight.shape[:-1], num_blocks, cls.BLOCK_SIZE) - - # Apply scale and quantize to FP8 E4M3 scale_factor_expanded = scale_factor.unsqueeze(-1) scaled_weight = weight_reshaped * scale_factor_expanded scaled_weight = torch.clamp(scaled_weight, min=-cls.E4M3_MAX, max=cls.E4M3_MAX) quantized_weight = scaled_weight.to(torch.float8_e4m3fn) - # Reshape back to original shape return quantized_weight.view(weight.shape) @classmethod @@ -215,31 +208,16 @@ def quantize(cls, input: torch.Tensor) -> tuple: original_shape = input.shape original_dtype = input.dtype - # Pad last dimension if not divisible by block_size - last_dim = original_shape[-1] - if last_dim % cls.BLOCK_SIZE != 0: - pad_size = cls.BLOCK_SIZE - (last_dim % cls.BLOCK_SIZE) - input = torch.nn.functional.pad(input, (0, pad_size), mode="constant", value=0) - - # Flatten to [num_blocks, block_size] for block-wise quantization - input_flat = input.view(-1, cls.BLOCK_SIZE) + input = reduce_block_padding(input, block_sizes={-1: cls.BLOCK_SIZE}) + input_amax = reduce_block_amax(input, block_sizes={-1: cls.BLOCK_SIZE}) - # Compute amax per block and E8M0 scale - input_amax = input_flat.float().abs().max(dim=-1, keepdim=True).values e8m0_exponent = cls._compute_e8m0_exponent(input_amax) e8m0_scale = (e8m0_exponent + 127).to(cls.SCALE_DTYPE) - # Reshape scale to match padded input shape for quantize_with_e8m0_scale - padded_shape = list(original_shape) - padded_shape[-1] = input.shape[-1] - scale_shape = [*original_shape[:-1], input.shape[-1] // cls.BLOCK_SIZE] - e8m0_scale = e8m0_scale.view(scale_shape) - - # Use quantize_with_e8m0_scale for the actual quantization (single source of truth) - quantized_data = cls.quantize_with_e8m0_scale(input.view(padded_shape), e8m0_scale) + quantized_data = cls.quantize_with_scale(input, e8m0_scale) # Crop back to original shape - quantized_data = quantized_data[..., :last_dim] + quantized_data = 
quantized_data[..., : original_shape[-1]] return cls(original_shape, original_dtype, quantized_data), e8m0_scale @@ -264,45 +242,28 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: dtype = self.metadata["dtype"] original_shape = self.metadata["shape"] - last_dim = original_shape[-1] - quantized_data = self._quantized_data + quantized_data = self._quantized_data.float() + quantized_data = reduce_block_padding(quantized_data, block_sizes={-1: self.BLOCK_SIZE}) - # Validate scale shape matches expected number of blocks - padded_last_dim = ( - last_dim + (self.BLOCK_SIZE - last_dim % self.BLOCK_SIZE) % self.BLOCK_SIZE - ) - expected_num_blocks = (quantized_data.numel() // last_dim) * ( - padded_last_dim // self.BLOCK_SIZE + num_blocks = quantized_data.shape[-1] // self.BLOCK_SIZE + quantized_blocked = quantized_data.view( + *quantized_data.shape[:-1], num_blocks, self.BLOCK_SIZE ) - assert e8m0_scale.numel() == expected_num_blocks, ( - f"Scale has {e8m0_scale.numel()} elements but expected {expected_num_blocks} blocks" - ) - - # Pad last dimension if not divisible by block_size - if last_dim % self.BLOCK_SIZE != 0: - pad_size = self.BLOCK_SIZE - (last_dim % self.BLOCK_SIZE) - quantized_data = torch.nn.functional.pad( - quantized_data.float(), (0, pad_size), mode="constant", value=0 - ) - else: - quantized_data = quantized_data.float() - - # Flatten to [num_blocks, block_size] for block-wise dequantization - quantized_flat = quantized_data.view(-1, self.BLOCK_SIZE) # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127) descale = torch.exp2(e8m0_scale.float() - 127) - # Flatten scale to (num_blocks, 1) for broadcasting with quantized_flat - descale = descale.view(-1, 1) + # Reshape descale to match blocked tensor for broadcasting + expected_scale_shape = (*quantized_data.shape[:-1], num_blocks) + if descale.shape != expected_scale_shape and descale.numel() == num_blocks * ( + quantized_data.numel() // quantized_data.shape[-1] + ): + descale = descale.view(expected_scale_shape) - # Apply descale - dequantized = quantized_flat * descale + dequantized = quantized_blocked * descale.unsqueeze(-1) - # Reshape: account for padded shape, then crop back to original - padded_shape = list(original_shape) - padded_shape[-1] = quantized_data.shape[-1] - dequantized = dequantized.view(padded_shape) - dequantized = dequantized[..., :last_dim] + # Reshape and crop back to original shape + dequantized = dequantized.view(*quantized_data.shape[:-1], quantized_data.shape[-1]) + dequantized = dequantized[..., : original_shape[-1]] return dequantized.to(dtype) From 8f21cd82d14f032cbcf3dfc22726157b44e81751 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 7 Jan 2026 20:27:07 +0200 Subject: [PATCH 05/11] Improve formatting in test_mxfp8_quantize_boundary_values Signed-off-by: Daniel Serebrenik --- .../torch/quantization/test_qtensor_cuda.py | 84 ++----------------- 1 file changed, 8 insertions(+), 76 deletions(-) diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index ba401ea15..69d512c5b 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -715,92 +715,24 @@ def test_mxfp8_e8m0_scale_values(self, device): assert torch.allclose(dequant[1, :32].max(), torch.tensor(2.0, device=device), rtol=0.1) assert torch.allclose(dequant[1, 32:64].max(), torch.tensor(0.5, device=device), rtol=0.1) + # fmt: off 
@pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize( "test_input", [ # FP8 E4M3 boundary test values (max is 448, various powers of 2) - torch.tensor( - [ - [ - 1.0, - 2.0, - 4.0, - 8.0, - 16.0, - 32.0, - 64.0, - 128.0, - 256.0, - 448.0, - 0.5, - 0.25, - 0.125, - 0.0625, - 0.03125, - 0.015625, - -1.0, - -2.0, - -4.0, - -8.0, - -16.0, - -32.0, - -64.0, - -128.0, - -256.0, - -448.0, - -0.5, - -0.25, - -0.125, - -0.0625, - -0.03125, - -0.015625, - ] - ] - ), + torch.tensor([[1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 448.0, 0.5, 0.25, + 0.125, 0.0625, 0.03125, 0.015625, -1.0, -2.0, -4.0, -8.0, -16.0, -32.0, + -64.0, -128.0, -256.0, -448.0, -0.5, -0.25, -0.125, -0.0625, -0.03125, -0.015625]]), # Mix of positive and negative values near E4M3 boundaries - torch.tensor( - [ - [ - 448.0, - 416.0, - 384.0, - 352.0, - 320.0, - 288.0, - 256.0, - 224.0, - 192.0, - 160.0, - 128.0, - 96.0, - 64.0, - 48.0, - 32.0, - 24.0, - -448.0, - -416.0, - -384.0, - -352.0, - -320.0, - -288.0, - -256.0, - -224.0, - -192.0, - -160.0, - -128.0, - -96.0, - -64.0, - -48.0, - -32.0, - -24.0, - ] - ] - ), + torch.tensor([[448.0, 416.0, 384.0, 352.0, 320.0, 288.0, 256.0, 224.0, 192.0, 160.0, + 128.0, 96.0, 64.0, 48.0, 32.0, 24.0, -448.0, -416.0, -384.0, -352.0, -320.0, + -288.0, -256.0, -224.0, -192.0, -160.0, -128.0, -96.0, -64.0, -48.0, -32.0, -24.0]]), ], ) def test_mxfp8_quantize_boundary_values(self, test_input, device, input_dtype): + # fmt: on """Test MXFP8 quantization with E4M3 boundary values.""" x = test_input.to(input_dtype).to(device) qtensor, e8m0_scale = MXFP8QTensor.quantize(x) From a764b32af1f4ad4f3aa11b27f6f82cdd030db11c Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Wed, 7 Jan 2026 20:59:07 +0200 Subject: [PATCH 06/11] Add more tests for MXFP8 (error handling) Signed-off-by: Daniel Serebrenik --- .../torch/quantization/test_qtensor_cuda.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index 69d512c5b..c3309832d 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -804,3 +804,86 @@ def test_mxfp8_get_weights_scaling_factor(self, device, input_shape): # The code clamps unbiased exponent to [-127, 127], giving biased range [0, 254] # Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)" + + @pytest.mark.parametrize( + ("amax_value", "expected_exponent"), + [ + (0.0, -127.0), # Zero amax: minimum exponent + (448.0, 0.0), # E4M3_MAX: exponent 0 + (1.0, -8.0), # log2(1/448) ~ -8.8, ceil = -8 + (1e40, 127.0), # Very large amax: clamps to max + (1e-50, -127.0), # Very small amax: clamps to min + ], + ) + def test_mxfp8_compute_e8m0_exponent_edge_cases(self, amax_value, expected_exponent): + """Test _compute_e8m0_exponent handles edge cases correctly.""" + amax = torch.tensor([amax_value], device="cuda") + exponent = MXFP8QTensor._compute_e8m0_exponent(amax) + assert exponent.item() == expected_exponent, ( + f"amax={amax_value} should give exponent {expected_exponent}, got {exponent.item()}" + ) + + def test_mxfp8_get_weights_scaling_factor_asserts_1d_weight(self): + """Test get_weights_scaling_factor raises assertion for 1D tensor.""" + weight_1d = torch.randn(64, device="cuda") + with 
pytest.raises(AssertionError, match="Weight must be at least 2D"): + MXFP8QTensor.get_weights_scaling_factor(weight_1d) + + def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self): + """Test get_weights_scaling_factor raises assertion when dim not divisible by 32.""" + # 33 is not divisible by 32 + weight = torch.randn(64, 33, device="cuda") + with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): + MXFP8QTensor.get_weights_scaling_factor(weight) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_quantize_with_scale_asserts(self, device): + """Test quantize_with_scale raises assertions for invalid inputs.""" + # Test 1D weight assertion + weight_1d = torch.randn(64, dtype=torch.float32, device=device) + scale = torch.randint(0, 255, (2,), dtype=torch.uint8, device=device) + with pytest.raises(AssertionError, match="Weight must be at least 2D"): + MXFP8QTensor.quantize_with_scale(weight_1d, scale) + + # Test wrong scale dtype assertion + weight = torch.randn(64, 64, dtype=torch.float32, device=device) + wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device) + with pytest.raises(AssertionError, match="e8m0_scale must be"): + MXFP8QTensor.quantize_with_scale(weight, wrong_dtype_scale) + + # Test non-divisible dimension assertion + weight_bad_dim = torch.randn(64, 33, dtype=torch.float32, device=device) + scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device) + with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): + MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_quantize_dequantize_asserts(self, device): + """Test quantize and dequantize raise assertions for invalid inputs.""" + # Test empty tensor assertion + empty_tensor = torch.empty(0, dtype=torch.float32, device=device) + with pytest.raises(AssertionError, match="Input tensor must not be empty"): + MXFP8QTensor.quantize(empty_tensor) + + # Test 0D tensor assertion + scalar_tensor = torch.tensor(1.0, dtype=torch.float32, device=device) + with pytest.raises(AssertionError, match="Input must have at least 1 dimension"): + MXFP8QTensor.quantize(scalar_tensor) + + # Test non-floating point assertion + int_tensor = torch.randint(0, 10, (32, 32), dtype=torch.int32, device=device) + with pytest.raises(AssertionError, match="Input must be floating point"): + MXFP8QTensor.quantize(int_tensor) + + # Create a valid quantized tensor for dequantize tests + input_tensor = torch.randn(64, 64, dtype=torch.float32, device=device) + qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor) + + # Test missing scale assertion + with pytest.raises(AssertionError, match="dequantize requires 'scale' in kwargs"): + qtensor.dequantize(dtype=torch.float32) + + # Test wrong scale dtype assertion + wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device) + with pytest.raises(AssertionError, match="e8m0_scale must be"): + qtensor.dequantize(dtype=torch.float32, scale=wrong_dtype_scale) From 43d3591be5ad0bb3ae2511ab1a240b8e410203eb Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 15 Jan 2026 11:06:47 +0200 Subject: [PATCH 07/11] Remove excessive assertions in MXFP8QTensor Signed-off-by: Daniel Serebrenik --- .../quantization/qtensor/mxfp8_tensor.py | 13 ++----- .../torch/quantization/test_qtensor_cuda.py | 37 ------------------- 2 files changed, 3 insertions(+), 47 deletions(-) diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py 
b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py index fa9345ab1..0c8439157 100644 --- a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -148,14 +148,14 @@ def quantize_with_scale( This method is useful for export paths where the scale has already been computed. Args: - weight: The weight tensor to quantize. Must be at least 2D. + weight: The weight tensor to quantize. Must be at least 1D. e8m0_scale: E8M0 scale as uint8 biased exponent (bias = 127). - Shape should be [..., out_dim, in_dim // 32]. + Shape should be [..., out_dim, in_dim // 32] for 2D+ tensors, + or [in_dim // 32] for 1D tensors. Returns: torch.Tensor: Quantized weight as float8_e4m3fn with same shape as input. """ - assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" assert e8m0_scale.dtype == cls.SCALE_DTYPE, ( f"e8m0_scale must be {cls.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}" ) @@ -201,10 +201,6 @@ def quantize(cls, input: torch.Tensor) -> tuple: Returns: tuple: (MXFP8QTensor, e8m0_scale) where e8m0_scale is uint8 biased exponent. """ - assert input.numel() > 0, "Input tensor must not be empty" - assert input.dim() >= 1, f"Input must have at least 1 dimension, got {input.dim()}D" - assert input.is_floating_point(), f"Input must be floating point, got {input.dtype}" - original_shape = input.shape original_dtype = input.dtype @@ -234,9 +230,6 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: assert "scale" in kwargs, "dequantize requires 'scale' in kwargs" e8m0_scale = kwargs["scale"] - assert e8m0_scale.dtype == self.SCALE_DTYPE, ( - f"e8m0_scale must be {self.SCALE_DTYPE} (E8M0 format), got {e8m0_scale.dtype}" - ) if dtype is None: dtype = self.metadata["dtype"] diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index c3309832d..7a1733e18 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -839,12 +839,6 @@ def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self): @pytest.mark.parametrize("device", ["cuda"]) def test_mxfp8_quantize_with_scale_asserts(self, device): """Test quantize_with_scale raises assertions for invalid inputs.""" - # Test 1D weight assertion - weight_1d = torch.randn(64, dtype=torch.float32, device=device) - scale = torch.randint(0, 255, (2,), dtype=torch.uint8, device=device) - with pytest.raises(AssertionError, match="Weight must be at least 2D"): - MXFP8QTensor.quantize_with_scale(weight_1d, scale) - # Test wrong scale dtype assertion weight = torch.randn(64, 64, dtype=torch.float32, device=device) wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device) @@ -856,34 +850,3 @@ def test_mxfp8_quantize_with_scale_asserts(self, device): scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device) with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale) - - @pytest.mark.parametrize("device", ["cuda"]) - def test_mxfp8_quantize_dequantize_asserts(self, device): - """Test quantize and dequantize raise assertions for invalid inputs.""" - # Test empty tensor assertion - empty_tensor = torch.empty(0, dtype=torch.float32, device=device) - with pytest.raises(AssertionError, match="Input tensor must not be empty"): - MXFP8QTensor.quantize(empty_tensor) - - # Test 0D tensor assertion - scalar_tensor = torch.tensor(1.0, 
dtype=torch.float32, device=device) - with pytest.raises(AssertionError, match="Input must have at least 1 dimension"): - MXFP8QTensor.quantize(scalar_tensor) - - # Test non-floating point assertion - int_tensor = torch.randint(0, 10, (32, 32), dtype=torch.int32, device=device) - with pytest.raises(AssertionError, match="Input must be floating point"): - MXFP8QTensor.quantize(int_tensor) - - # Create a valid quantized tensor for dequantize tests - input_tensor = torch.randn(64, 64, dtype=torch.float32, device=device) - qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor) - - # Test missing scale assertion - with pytest.raises(AssertionError, match="dequantize requires 'scale' in kwargs"): - qtensor.dequantize(dtype=torch.float32) - - # Test wrong scale dtype assertion - wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device) - with pytest.raises(AssertionError, match="e8m0_scale must be"): - qtensor.dequantize(dtype=torch.float32, scale=wrong_dtype_scale) From b17bca9559f861b9f7689d04bdc8ce995fa68ea6 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 15 Jan 2026 11:17:17 +0200 Subject: [PATCH 08/11] Add support and tests for 3D MoE in MXFP8QTensor Signed-off-by: Daniel Serebrenik --- .../quantization/qtensor/mxfp8_tensor.py | 32 +++++++++---- .../torch/quantization/test_qtensor_cuda.py | 47 +++++++++++++++++++ 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py index 0c8439157..6b0931726 100644 --- a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -71,9 +71,12 @@ def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor: Args: weight: The weight tensor to compute scale for. Must be at least 2D. + Supports 2D (out_dim, in_dim) and 3D MoE (num_experts, out_dim, in_dim). Returns: torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32]. + For 2D input: (out_dim, in_dim // 32) + For 3D MoE input: (num_experts, out_dim, in_dim // 32) """ assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" @@ -83,7 +86,7 @@ def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor: f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" ) - # Compute amax per block (reduce_block_amax handles reshaping internally) + # Compute amax per block (reduce_block_amax handles N-dimensional tensors) amax = reduce_block_amax(weight, block_sizes={-1: cls.BLOCK_SIZE}) # Compute E8M0 exponent and convert to biased uint8 (bias = 127) @@ -102,11 +105,12 @@ def get_weights_scaling_factor_from_quantizer( with proper format conversion and shape correction. Args: - weight: The weight tensor. + weight: The weight tensor. Can be 2D (out_dim, in_dim) or + 3D for MoE (num_experts, out_dim, in_dim). weight_quantizer: The weight quantizer with block_sizes and optional _scale. Returns: - torch.Tensor: E8M0 scale as uint8 tensor with shape [out_dim, in_dim // 32]. + torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32]. 
""" assert hasattr(weight_quantizer, "block_sizes"), ( "weight_quantizer must have 'block_sizes' attribute" @@ -116,8 +120,11 @@ def get_weights_scaling_factor_from_quantizer( ) assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" - out_dim, in_dim = weight.shape[-2], weight.shape[-1] - expected_shape = (out_dim, in_dim // cls.BLOCK_SIZE) + in_dim = weight.shape[-1] + # Expected scale shape: all dims except last, with last dim reduced by block size + # For 2D: (out_dim, in_dim // 32) + # For 3D MoE: (num_experts, out_dim, in_dim // 32) + expected_shape = (*weight.shape[:-1], in_dim // cls.BLOCK_SIZE) if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: scale = weight_quantizer._scale @@ -127,11 +134,16 @@ def get_weights_scaling_factor_from_quantizer( ) # Reshape if needed (same number of elements but wrong shape) - if ( - scale.shape != expected_shape - and scale.numel() == expected_shape[0] * expected_shape[1] - ): - scale = scale.reshape(expected_shape) + if scale.shape != expected_shape: + expected_numel = 1 + for dim in expected_shape: + expected_numel *= dim + if scale.numel() == expected_numel: + scale = scale.reshape(expected_shape) + + assert scale.shape == expected_shape, ( + f"Scale shape {scale.shape} does not match expected shape {expected_shape}" + ) return scale # No scale in quantizer, compute from weight diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index 7a1733e18..c995124e6 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -850,3 +850,50 @@ def test_mxfp8_quantize_with_scale_asserts(self, device): scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device) with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_get_weights_scaling_factor_from_quantizer_3d_moe(self, device): + """Test get_weights_scaling_factor_from_quantizer handles 3D MoE tensors.""" + input_shape = (4, 64, 128) # (num_experts, out_dim, in_dim) + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + class MockQuantizer: + block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE} + _scale = None + + quantizer = MockQuantizer() + + # Test when _scale is None (should compute from weight) + scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer) + + expected_shape = ( + input_shape[0], + input_shape[1], + input_shape[2] // MXFP8QTensor.BLOCK_SIZE, + ) + assert scale.shape == expected_shape + + # Test when _scale is provided with correct 3D shape + quantizer._scale = torch.randint(0, 255, expected_shape, dtype=torch.uint8, device=device) + scale_from_quantizer = MXFP8QTensor.get_weights_scaling_factor_from_quantizer( + weight, quantizer + ) + assert torch.equal(scale_from_quantizer, quantizer._scale) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_get_weights_scaling_factor_from_quantizer_scale_shape_mismatch(self, device): + """Test get_weights_scaling_factor_from_quantizer raises assertion on shape mismatch.""" + input_shape = (4, 64, 128) # (num_experts, out_dim, in_dim) + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + class MockQuantizer: + block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE} + # Wrong shape: 2D instead of 3D (missing num_experts dimension) + _scale = torch.randint( + 0, 
255, (64, 4), dtype=torch.uint8, device=device + ) + + quantizer = MockQuantizer() + + with pytest.raises(AssertionError, match="Scale shape .* does not match expected shape"): + MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer) From 626f18e0d0237b4487b95a4f8ef0d4457114826a Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 15 Jan 2026 11:32:35 +0200 Subject: [PATCH 09/11] Cleanup code in get_weights_scaling_factor_from_quantizer of MXFP8QTensor Signed-off-by: Daniel Serebrenik --- .../quantization/qtensor/mxfp8_tensor.py | 22 ------------------- .../torch/quantization/test_qtensor_cuda.py | 12 ++++++++++ 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py index 6b0931726..e87612f3a 100644 --- a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -132,15 +132,6 @@ def get_weights_scaling_factor_from_quantizer( assert scale.dtype == cls.SCALE_DTYPE, ( f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}" ) - - # Reshape if needed (same number of elements but wrong shape) - if scale.shape != expected_shape: - expected_numel = 1 - for dim in expected_shape: - expected_numel *= dim - if scale.numel() == expected_numel: - scale = scale.reshape(expected_shape) - assert scale.shape == expected_shape, ( f"Scale shape {scale.shape} does not match expected shape {expected_shape}" ) @@ -179,12 +170,6 @@ def quantize_with_scale( f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" ) - # Reshape scale if needed (same number of elements but wrong shape) - expected_shape = (*weight.shape[:-1], num_blocks) - if e8m0_scale.shape != expected_shape: - if e8m0_scale.numel() == weight.numel() // cls.BLOCK_SIZE: - e8m0_scale = e8m0_scale.reshape(expected_shape) - # Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent) scale_factor = torch.exp2(127 - e8m0_scale.float()) @@ -258,13 +243,6 @@ def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127) descale = torch.exp2(e8m0_scale.float() - 127) - # Reshape descale to match blocked tensor for broadcasting - expected_scale_shape = (*quantized_data.shape[:-1], num_blocks) - if descale.shape != expected_scale_shape and descale.numel() == num_blocks * ( - quantized_data.numel() // quantized_data.shape[-1] - ): - descale = descale.view(expected_scale_shape) - dequantized = quantized_blocked * descale.unsqueeze(-1) # Reshape and crop back to original shape diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index c995124e6..269f0fa63 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -897,3 +897,15 @@ class MockQuantizer: with pytest.raises(AssertionError, match="Scale shape .* does not match expected shape"): MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_mxfp8_dequantize_default_dtype(self, device, input_dtype): + """Test dequantize uses original dtype when dtype=None.""" + input_tensor = torch.randn(64, 64, dtype=input_dtype, device=device) + qtensor, e8m0_scale = 
MXFP8QTensor.quantize(input_tensor) + + # Dequantize without specifying dtype + dequant = qtensor.dequantize(scale=e8m0_scale) + + assert dequant.dtype == input_dtype From 1034e38b1b74d4305e0147e6c4284e0f130ad8d5 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 15 Jan 2026 21:22:20 +0200 Subject: [PATCH 10/11] Fix TensorQuantizer to handle MXFP8. Tested by test_qtensor_accuracy. Signed-off-by: Daniel Serebrenik --- .../nn/modules/tensor_quantizer.py | 48 ++++++++++--------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 66bafd47d..12fe28626 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -650,8 +650,31 @@ def _real_quantize(self, inputs): assert self._is_real_quantize_support(), "Real quantization not supported for this format." buffer_to_register = {} - if self._num_bits == (4, 3): - # FP8 quantization + # Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3) + if ( + self._block_sizes + and self._block_sizes.get("scale_bits") == (8, 0) + and self._block_sizes.get("type") == "dynamic" + ): + # MX quantization (MXFP4/MXFP8) + if self._num_bits == (2, 1): + # MXFP4 + outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) + buffer_to_register["_scale"] = scales + elif self._num_bits == (4, 3): + # MXFP8 + assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, ( + f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, " + f"got {self._block_sizes[-1]}" + ) + outputs, scales = MXFP8QTensor.quantize(inputs) + buffer_to_register["_scale"] = scales + else: + raise ValueError( + f"Real quantization for MX {self._num_bits} format is not supported." + ) + elif self._num_bits == (4, 3): + # FP8 quantization (non-MX) # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks # For blockwise quantization, amax will be recomputed in the kernel use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1) @@ -684,27 +707,6 @@ def _real_quantize(self, inputs): buffer_to_register["_scale"] = _scale buffer_to_register["_double_scale"] = _double_scale buffer_to_register["_scale_zeros"] = _scale_zeros - elif ( - self._block_sizes.get("scale_bits") == (8, 0) - and self._block_sizes.get("type") == "dynamic" - ): - # MX quantization - if self._num_bits == (2, 1): - # MXFP4 - outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) - buffer_to_register["_scale"] = scales - elif self._num_bits == (4, 3): - # MXFP8 - assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, ( - f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, " - f"got {self._block_sizes[-1]}" - ) - outputs, scales = MXFP8QTensor.quantize(inputs) - buffer_to_register["_scale"] = scales - else: - raise ValueError( - f"Real quantization for MX {self._num_bits} format is not supported." 
- ) elif self._block_sizes.get("scale_bits") == (4, 3): # NVFP4 default quantization # Return real quantized tensor and store scales inside TensorQuantizer From b7ed5ce3154b7ad3d32fe3c00a48b30232fbd543 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Thu, 15 Jan 2026 21:37:17 +0200 Subject: [PATCH 11/11] Improve error message in _real_quantize Signed-off-by: Daniel Serebrenik --- modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 12fe28626..0dde20eec 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -671,7 +671,8 @@ def _real_quantize(self, inputs): buffer_to_register["_scale"] = scales else: raise ValueError( - f"Real quantization for MX {self._num_bits} format is not supported." + f"Unsupported MX format: num_bits={self._num_bits}. " + f"Expected (2, 1) for MXFP4 or (4, 3) for MXFP8." ) elif self._num_bits == (4, 3): # FP8 quantization (non-MX)
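
---

Reviewer note (not part of any patch above): a minimal usage sketch of the MXFP8QTensor API introduced in this series, mirroring the GPU tests. It assumes a CUDA device, that MXFP8QTensor is exported from modelopt.torch.quantization.qtensor, and the example shapes are arbitrary. Scales are E8M0 biased exponents stored as uint8, one per 32-element block along the last dimension, and dequantization applies descale = 2**(exponent - 127).

    import torch

    # Assumed import path, matching how the export utilities and tests reference the class.
    from modelopt.torch.quantization.qtensor import MXFP8QTensor

    # Quantize a 2D weight: returns the FP8 (E4M3) payload wrapped in an MXFP8QTensor
    # plus one uint8 E8M0 scale per 32-element block along the last dimension.
    weight = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
    qtensor, e8m0_scale = MXFP8QTensor.quantize(weight)

    assert e8m0_scale.dtype == torch.uint8
    assert e8m0_scale.shape == (128, 64 // MXFP8QTensor.BLOCK_SIZE)
    # Private attribute, accessed here only for inspection as the tests do.
    assert qtensor._quantized_data.dtype == torch.float8_e4m3fn

    # Dequantize back to the original dtype (dtype=None falls back to the input dtype);
    # each 32-element block is multiplied by descale = 2**(e8m0_scale - 127).
    recovered = qtensor.dequantize(scale=e8m0_scale)
    print((recovered.float() - weight.float()).abs().max())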