From 51f5e86aaf90bc422593efdee44843167f6fb194 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 15 Jan 2026 19:32:57 +0000 Subject: [PATCH 1/3] add FP8 scale sweep option Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/calib/mse.py | 53 ++++++++++++--- modelopt/torch/quantization/config.py | 34 ++++++++++ modelopt/torch/quantization/model_calib.py | 13 ++++ .../torch/quantization/test_quantize_cuda.py | 6 +- .../torch/quantization/test_mse_calibrator.py | 67 +++++++++++++++++++ 5 files changed, 163 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index c94b7d716..d99d27525 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -39,6 +39,7 @@ def __init__( stop_multiplier: float = 4.0, quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + fp8_scale_sweep: bool = False, ): """Initialize MSE calibrator. @@ -46,13 +47,16 @@ def __init__( amax: Initial amax value (required). axis: Quantization axis. None means per-tensor quantization. step_size: Step size for amax search. The number of steps is computed as - ceil((stop_multiplier - start_multiplier) / step_size) + 1. + ceil((stop_multiplier - start_multiplier) / step_size) + 1. start_multiplier: Starting multiplier for amax search. stop_multiplier: Ending multiplier for amax search. quant_func: Function that quantizes input tensor given an amax value. - Should have signature: quant_func(x, amax) -> quantized_x. + Should have signature: quant_func(x, amax) -> quantized_x. error_func: Function to compute error between x and xq. - Default is F.mse_loss(x, xq, reduction='none'). + Default is F.mse_loss(x, xq, reduction='none'). + fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values + instead of using multipliers. This is specifically for NVFP4 + per-block quantization where scales are stored in FP8 format. 
""" super().__init__(num_bits=None, axis=axis, unsigned=None) self._initial_amax = amax @@ -65,6 +69,13 @@ def __init__( self._error_func = error_func self._losses_sum = [None] * self._num_steps self._candidate_amaxs = [None] * self._num_steps + self._fp8_scale_sweep = fp8_scale_sweep + if fp8_scale_sweep: + # For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values + # (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN) + self._num_steps = 126 + self._losses_sum = [None] * self._num_steps + self._candidate_amaxs = [None] * self._num_steps self._amax = None @@ -83,14 +94,40 @@ def collect(self, x: torch.Tensor): x = x.detach().to(dtype=torch.float32) device = x.device - multipliers = torch.linspace( - self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device - ) + + if self._fp8_scale_sweep: + global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) + global_amax_expanded = global_amax * torch.ones_like(self._initial_amax) + + # Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn) + # Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32 + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + + # Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + fp8_values_valid = fp8_values[valid_mask] + + candidates = fp8_values_valid / 448.0 + + print( + f"FP8 scale sweep: trying {len(candidates)} valid FP8 E4M3 multipliers (out of 128 total)" + ) + print( + f"Multiplier range: {candidates.min().item():.6e} to {candidates.max().item():.6e}" + ) + else: + candidates = torch.linspace( + self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device + ) # Get reduce axis for per-channel quantization reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) - for step, multiplier in enumerate(multipliers): - candidate_amax = self._initial_amax * multiplier + for step, candidate in enumerate(candidates): + if self._fp8_scale_sweep: + candidate_amax = global_amax_expanded * candidate + else: + candidate_amax = self._initial_amax * candidate xq = self._quant_func(x, candidate_amax) if self._error_func is not None: diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 2772e8138..3b6303af8 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -411,6 +411,28 @@ }, } +NVFP4_WEIGHT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "mse", + # "step_size": 0.5, + # "start_multiplier": 0.25, + # "stop_multiplier": 2.0, + "fp8_scale_sweep": True, + }, +} + NVFP4_AWQ_LITE_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -1040,6 +1062,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig): reconstruction error of a tensor after uniform Q→DQ: s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations} + + When fp8_scale_sweep is enabled, step_size is ignored. 
""" method: Literal["mse"] = ModeloptField("mse") @@ -1066,6 +1090,16 @@ class MseCalibConfig(QuantizeAlgorithmConfig): description="Ending multiplier for amax search range (multiplies initial amax).", ) + fp8_scale_sweep: bool | None = ModeloptField( + default=False, + title="Enable FP8 scale sweep for NVFP4 per-block quantization.", + description="If True, sweep over all 128 possible FP8 E4M3 scale values " + "for NVFP4 per-block quantization instead of using multipliers. " + "This is specifically designed for optimizing the FP8-quantized per-block scales " + "in NVFP4 format. When enabled, num_steps, step_size, start_multiplier, and " + "stop_multiplier are ignored for NVFP4 per-block quantizers.", + ) + distributed_sync: bool | None = ModeloptField( default=True, title="Whether to sync the amax across the distributed processes.", diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index b8461a080..da284e25f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -205,6 +205,7 @@ def mse_calibrate( step_size: float = 0.1, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, + fp8_scale_sweep: bool = False, ): """Calibrate the model using MSE-based amax search. @@ -220,6 +221,10 @@ def mse_calibrate( step_size: Step size for amax search (default: 0.1). start_multiplier: Starting multiplier for amax search (default: 0.25). stop_multiplier: Ending multiplier for amax search (default: 4.0). + fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values + for NVFP4 per-block quantization instead of using multipliers. + This is specifically designed for optimizing the FP8-quantized + per-block scales in NVFP4 format (default: False). See :class:`MseCalibConfig ` for details on the remaining arguments. 
@@ -260,6 +265,13 @@ def quant_func(x, amax, quantizer=module): return xq + is_nvfp4_per_block = ( + fp8_scale_sweep + and module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes.get("scale_bits") == (4, 3) + ) + # Create MSE calibrator with quant_func module._calibrator = MseCalibrator( amax=initial_amax, @@ -268,6 +280,7 @@ def quant_func(x, amax, quantizer=module): start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, quant_func=quant_func, + fp8_scale_sweep=is_nvfp4_per_block, ) # Identify weight quantizers by checking if they have corresponding weight parameters diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index 9d82c1082..f789d0a56 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -70,7 +70,8 @@ mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, - NVFP4_WEIGHT_ACT_MSE_CFG, + mtq.NVFP4_WEIGHT_ACT_MSE_CFG, + mtq.NVFP4_WEIGHT_MSE_CFG, ], ) def test_quantize(model_cls, config): @@ -87,7 +88,8 @@ def test_quantize(model_cls, config): mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, - NVFP4_WEIGHT_ACT_MSE_CFG, + mtq.NVFP4_WEIGHT_ACT_MSE_CFG, + mtq.NVFP4_WEIGHT_MSE_CFG, ]: if get_cuda_ext_mx() is None: pytest.skip("cuda_ext_mx is not available") diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 5e5546512..efccec4c4 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -526,3 +526,70 @@ def quant_func(x, amax): assert a_best.numel() == 2 assert torch.all(torch.isfinite(a_best)) assert torch.all(a_best > 0) + + def test_fp8_scale_sweep_with_fixed_values_and_reset(self): + """Test FP8 scale sweep with fixed hand-written values and reset functionality.""" + x = torch.full((100,), 2.0, dtype=torch.float32) + x[0] = 20.0 + + initial_amax = torch.tensor(20.0) + + quant_cfg = QuantizerAttributeConfig(num_bits=(4, 3), axis=None, unsigned=False) + tq = TensorQuantizer(quant_attribute_cfg=quant_cfg, amax=initial_amax) + + def quant_func(x, amax): + original_amax = tq._amax.clone() if hasattr(tq, "_amax") else None + was_quant_enabled = tq._if_quant + was_calib_enabled = tq._if_calib + + tq._amax = amax + tq._if_quant = True + tq._if_calib = False + + with enable_fake_quant(tq): + xq = tq(x) + + if original_amax is not None: + tq._amax = original_amax + tq._if_quant = was_quant_enabled + tq._if_calib = was_calib_enabled + return xq + + cal = calib.MseCalibrator( + amax=initial_amax, + quant_func=quant_func, + fp8_scale_sweep=True, + ) + + assert cal._num_steps == 126 + + cal.collect(x) + + a_best = cal.compute_amax() + + assert torch.isfinite(a_best), "Optimal amax should be finite" + assert a_best > 0, "Optimal amax should be positive" + assert a_best <= initial_amax, "Optimal amax should not exceed initial amax" + + # FP8 scale sweep uses global_amax * fp8_multiplier where fp8_multiplier + # ranges from ~4.36e-06 to 1.0. 
For mostly 2.0 values with one 20.0 outlier, + # the optimal amax should be somewhere between these extremes + assert a_best >= initial_amax * 1e-6, "Optimal amax should not be unreasonably small" + + a_best_value = a_best.item() + + cal.reset() + + a_after_reset = cal.compute_amax() + assert a_after_reset is None, "After reset, compute_amax should return None" + + assert cal._num_steps == 126, "After reset, num_steps should still be 126" + + cal.collect(x) + a_best_after_reset = cal.compute_amax() + + assert torch.isfinite(a_best_after_reset), "Should be able to compute amax after reset" + assert a_best_after_reset > 0, "Amax after reset should be positive" + assert abs(a_best_after_reset.item() - a_best_value) < 1e-6, ( + "Amax after reset should match original value with same data" + ) From d28b5951e5370aa402dc2174f602311006491edf Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 15 Jan 2026 19:53:06 +0000 Subject: [PATCH 2/3] minor clean up Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/calib/mse.py | 7 ------ modelopt/torch/quantization/config.py | 5 +---- .../torch/quantization/test_quantize_cuda.py | 22 ++----------------- 3 files changed, 3 insertions(+), 31 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index d99d27525..4381d54f7 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -109,13 +109,6 @@ def collect(self, x: torch.Tensor): fp8_values_valid = fp8_values[valid_mask] candidates = fp8_values_valid / 448.0 - - print( - f"FP8 scale sweep: trying {len(candidates)} valid FP8 E4M3 multipliers (out of 128 total)" - ) - print( - f"Multiplier range: {candidates.min().item():.6e} to {candidates.max().item():.6e}" - ) else: candidates = torch.linspace( self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 3b6303af8..e006fc2dc 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -411,7 +411,7 @@ }, } -NVFP4_WEIGHT_MSE_CFG = { +NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = { "quant_cfg": { "*weight_quantizer": { "num_bits": (2, 1), @@ -426,9 +426,6 @@ }, "algorithm": { "method": "mse", - # "step_size": 0.5, - # "start_multiplier": 0.25, - # "stop_multiplier": 2.0, "fp8_scale_sweep": True, }, } diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index f789d0a56..95fec28a8 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -28,24 +28,6 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.extensions import get_cuda_ext_mx -NVFP4_WEIGHT_ACT_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - }, - "algorithm": "mse", -} - @pytest.mark.parametrize("model_cls", [SimpleLinear, SimpleConv, SimpleConvLinear]) @pytest.mark.parametrize( @@ -71,7 +53,7 @@ mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.NVFP4_WEIGHT_ACT_MSE_CFG, - mtq.NVFP4_WEIGHT_MSE_CFG, + 
mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, ], ) def test_quantize(model_cls, config): @@ -89,7 +71,7 @@ def test_quantize(model_cls, config): mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.NVFP4_WEIGHT_ACT_MSE_CFG, - mtq.NVFP4_WEIGHT_MSE_CFG, + mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, ]: if get_cuda_ext_mx() is None: pytest.skip("cuda_ext_mx is not available") From 75e3ccd92ba0930446acb1a49381c2d656b60d62 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 15 Jan 2026 19:57:53 +0000 Subject: [PATCH 3/3] minor Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index da284e25f..2575b5913 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -269,6 +269,7 @@ def quant_func(x, amax, quantizer=module): fp8_scale_sweep and module.is_static_block_quant and module._num_bits == (2, 1) + and module._block_sizes is not None and module._block_sizes.get("scale_bits") == (4, 3) )
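
[Editor's note, not part of the patch] A minimal usage sketch for the new config, assuming the standard ModelOpt entry point mtq.quantize(model, config, forward_loop). The model, calibration batches, and forward_loop below are illustrative placeholders, and passing a forward_loop for this weight-only config is an assumption rather than a requirement stated in the patch.

    import torch
    import modelopt.torch.quantization as mtq

    # Toy model and calibration data; real usage would substitute an actual
    # model and representative calibration batches.
    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8)
    )
    calib_batches = [torch.randn(4, 64) for _ in range(8)]

    def forward_loop(m):
        for batch in calib_batches:
            m(batch)

    # Weight-only NVFP4 with MSE calibration; "fp8_scale_sweep": True in the
    # algorithm section makes the calibrator try the 126 valid FP8 E4M3
    # per-block scale values instead of the linspace multipliers.
    model = mtq.quantize(model, mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)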