46 changes: 38 additions & 8 deletions modelopt/torch/quantization/calib/mse.py
@@ -39,20 +39,24 @@ def __init__(
stop_multiplier: float = 4.0,
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
fp8_scale_sweep: bool = False,
):
"""Initialize MSE calibrator.

Args:
amax: Initial amax value (required).
axis: Quantization axis. None means per-tensor quantization.
step_size: Step size for amax search. The number of steps is computed as
ceil((stop_multiplier - start_multiplier) / step_size) + 1.
start_multiplier: Starting multiplier for amax search.
stop_multiplier: Ending multiplier for amax search.
quant_func: Function that quantizes input tensor given an amax value.
Should have signature: quant_func(x, amax) -> quantized_x.
error_func: Function to compute error between x and xq.
Default is F.mse_loss(x, xq, reduction='none').
fp8_scale_sweep: If True, sweep over the 126 valid (positive, finite) FP8 E4M3
scale values instead of the multiplier grid. This is specifically for NVFP4
per-block quantization, where the per-block scales are stored in FP8 format.
"""
super().__init__(num_bits=None, axis=axis, unsigned=None)
self._initial_amax = amax
@@ -65,6 +69,13 @@ def __init__(
self._error_func = error_func
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
self._fp8_scale_sweep = fp8_scale_sweep
if fp8_scale_sweep:
# For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
# (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
self._num_steps = 126
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps

self._amax = None

@@ -83,14 +94,33 @@ def collect(self, x: torch.Tensor):
x = x.detach().to(dtype=torch.float32)

device = x.device

if self._fp8_scale_sweep:
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)

# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()

# Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values_valid = fp8_values[valid_mask]

candidates = fp8_values_valid / 448.0
else:
candidates = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)
# Get reduce axis for per-channel quantization
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

for step, candidate in enumerate(candidates):
if self._fp8_scale_sweep:
candidate_amax = global_amax_expanded * candidate
else:
candidate_amax = self._initial_amax * candidate
xq = self._quant_func(x, candidate_amax)

if self._error_func is not None:
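For reference, the candidate enumeration used in `collect()` above can be reproduced standalone. This is a minimal sketch (requires a PyTorch build with `torch.float8_e4m3fn`, i.e. 2.1+); the variable names are illustrative only:

```python
import torch

# Enumerate every FP8 E4M3 bit pattern from 0 to 127 and reinterpret the bytes as
# float8_e4m3fn. Byte 0 decodes to zero and byte 127 decodes to NaN, so keeping only
# positive finite values leaves exactly 126 usable candidates.
uint8_codes = torch.arange(0, 128, dtype=torch.uint8)
fp8_values = uint8_codes.view(torch.float8_e4m3fn).float()
candidates = fp8_values[torch.isfinite(fp8_values) & (fp8_values > 0)]
assert candidates.numel() == 126

# Dividing by 448.0 (the largest representable E4M3 value) maps each candidate to a
# multiplier in (0, 1]; each candidate amax is then global_amax * multiplier.
multipliers = candidates / 448.0
print(multipliers.min().item(), multipliers.max().item())  # ~4.36e-06 and 1.0
```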
31 changes: 31 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -411,6 +411,25 @@
},
}

NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"fp8_scale_sweep": True,
},
}

NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
@@ -1040,6 +1059,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
reconstruction error of a tensor after uniform Q→DQ:

s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}

When fp8_scale_sweep is enabled, step_size, start_multiplier, and stop_multiplier are ignored.
"""

method: Literal["mse"] = ModeloptField("mse")
@@ -1066,6 +1087,16 @@
description="Ending multiplier for amax search range (multiplies initial amax).",
)

fp8_scale_sweep: bool | None = ModeloptField(
default=False,
title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
description="If True, sweep over all 128 possible FP8 E4M3 scale values "
"for NVFP4 per-block quantization instead of using multipliers. "
"This is specifically designed for optimizing the FP8-quantized per-block scales "
"in NVFP4 format. When enabled, num_steps, step_size, start_multiplier, and "
"stop_multiplier are ignored for NVFP4 per-block quantizers.",
)

distributed_sync: bool | None = ModeloptField(
default=True,
title="Whether to sync the amax across the distributed processes.",
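As a usage sketch, the new config should plug into the standard `mtq.quantize` calibration flow like any other quantization config. The model, data, and `forward_loop` below are placeholders, not part of this PR; for this weight-only config the forward passes simply drive weight calibration:

```python
import torch
import modelopt.torch.quantization as mtq

# Toy model and calibration batches (placeholders for illustration).
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
calib_data = [torch.randn(4, 64) for _ in range(8)]

def forward_loop(m):
    # A few forward passes so the MSE calibrators can collect statistics.
    for batch in calib_data:
        m(batch)

# Weight-only NVFP4 quantization whose per-block FP8 scales are selected by the
# exhaustive FP8 E4M3 sweep instead of the start/stop multiplier grid.
model = mtq.quantize(model, mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)
```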
13 changes: 13 additions & 0 deletions modelopt/torch/quantization/model_calib.py
@@ -205,6 +205,7 @@ def mse_calibrate(
step_size: float = 0.1,
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
fp8_scale_sweep: bool = False,
):
"""Calibrate the model using MSE-based amax search.

@@ -220,6 +221,10 @@
step_size: Step size for amax search (default: 0.1).
start_multiplier: Starting multiplier for amax search (default: 0.25).
stop_multiplier: Ending multiplier for amax search (default: 4.0).
fp8_scale_sweep: If True, sweep over the 126 valid (positive, finite) FP8 E4M3
scale values for NVFP4 per-block quantization instead of the multiplier grid.
This is specifically designed for optimizing the FP8-quantized
per-block scales in NVFP4 format (default: False).

See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
@@ -260,6 +265,13 @@ def quant_func(x, amax, quantizer=module):

return xq

is_nvfp4_per_block = (
fp8_scale_sweep
and module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes.get("scale_bits") == (4, 3)
)

# Create MSE calibrator with quant_func
module._calibrator = MseCalibrator(
amax=initial_amax,
Expand All @@ -268,6 +280,7 @@ def quant_func(x, amax, quantizer=module):
start_multiplier=start_multiplier,
stop_multiplier=stop_multiplier,
quant_func=quant_func,
fp8_scale_sweep=is_nvfp4_per_block,
)

# Identify weight quantizers by checking if they have corresponding weight parameters
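To make the search criterion concrete, here is a small self-contained sketch of the default multiplier sweep, i.e. s* = argmin_s E[(X - DQ(Q(X; s)))^2]. It is not the modelopt implementation: `fake_quant_int8` is a hypothetical stand-in for `quant_func`, and the constants mirror the default step_size/start_multiplier/stop_multiplier:

```python
import math

import torch
import torch.nn.functional as F

def fake_quant_int8(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Symmetric int8 fake quantization: quantize then dequantize with scale = amax / 127.
    scale = amax / 127.0
    return torch.clamp(torch.round(x / scale), -127, 127) * scale

x = torch.randn(4096) * 2.0
x[0] = 30.0  # a single outlier inflates the plain max-based amax
initial_amax = x.abs().max()

# Candidate amax values: initial_amax scaled by multipliers in [0.25, 4.0] with
# step 0.1, matching ceil((stop_multiplier - start_multiplier) / step_size) + 1 steps.
steps = math.ceil((4.0 - 0.25) / 0.1) + 1
multipliers = torch.linspace(0.25, 4.0, steps=steps)

losses = torch.stack([F.mse_loss(x, fake_quant_int8(x, initial_amax * m)) for m in multipliers])
best_amax = initial_amax * multipliers[losses.argmin()]
print(initial_amax.item(), best_amax.item())  # the MSE-optimal amax typically clips the outlier
```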
24 changes: 4 additions & 20 deletions tests/gpu/torch/quantization/test_quantize_cuda.py
@@ -28,24 +28,6 @@
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.extensions import get_cuda_ext_mx

NVFP4_WEIGHT_ACT_MSE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
},
"algorithm": "mse",
}


@pytest.mark.parametrize("model_cls", [SimpleLinear, SimpleConv, SimpleConvLinear])
@pytest.mark.parametrize(
@@ -70,7 +52,8 @@
mtq.MXINT8_DEFAULT_CFG,
mtq.NVFP4_KV_ROTATE_CFG,
mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.NVFP4_WEIGHT_ACT_MSE_CFG,
mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
],
)
def test_quantize(model_cls, config):
@@ -87,7 +70,8 @@ def test_quantize(model_cls, config):
mtq.MXINT8_DEFAULT_CFG,
mtq.NVFP4_KV_ROTATE_CFG,
mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
mtq.NVFP4_WEIGHT_ACT_MSE_CFG,
mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
]:
if get_cuda_ext_mx() is None:
pytest.skip("cuda_ext_mx is not available")
67 changes: 67 additions & 0 deletions tests/unit/torch/quantization/test_mse_calibrator.py
@@ -526,3 +526,70 @@ def quant_func(x, amax):
assert a_best.numel() == 2
assert torch.all(torch.isfinite(a_best))
assert torch.all(a_best > 0)

def test_fp8_scale_sweep_with_fixed_values_and_reset(self):
"""Test FP8 scale sweep with fixed hand-written values and reset functionality."""
x = torch.full((100,), 2.0, dtype=torch.float32)
x[0] = 20.0

initial_amax = torch.tensor(20.0)

quant_cfg = QuantizerAttributeConfig(num_bits=(4, 3), axis=None, unsigned=False)
tq = TensorQuantizer(quant_attribute_cfg=quant_cfg, amax=initial_amax)

def quant_func(x, amax):
original_amax = tq._amax.clone() if hasattr(tq, "_amax") else None
was_quant_enabled = tq._if_quant
was_calib_enabled = tq._if_calib

tq._amax = amax
tq._if_quant = True
tq._if_calib = False

with enable_fake_quant(tq):
xq = tq(x)

if original_amax is not None:
tq._amax = original_amax
tq._if_quant = was_quant_enabled
tq._if_calib = was_calib_enabled
return xq

cal = calib.MseCalibrator(
amax=initial_amax,
quant_func=quant_func,
fp8_scale_sweep=True,
)

assert cal._num_steps == 126

cal.collect(x)

a_best = cal.compute_amax()

assert torch.isfinite(a_best), "Optimal amax should be finite"
assert a_best > 0, "Optimal amax should be positive"
assert a_best <= initial_amax, "Optimal amax should not exceed initial amax"

# FP8 scale sweep uses global_amax * fp8_multiplier where fp8_multiplier
# ranges from ~4.36e-06 to 1.0. For mostly 2.0 values with one 20.0 outlier,
# the optimal amax should be somewhere between these extremes
assert a_best >= initial_amax * 1e-6, "Optimal amax should not be unreasonably small"

a_best_value = a_best.item()

cal.reset()

a_after_reset = cal.compute_amax()
assert a_after_reset is None, "After reset, compute_amax should return None"

assert cal._num_steps == 126, "After reset, num_steps should still be 126"

cal.collect(x)
a_best_after_reset = cal.compute_amax()

assert torch.isfinite(a_best_after_reset), "Should be able to compute amax after reset"
assert a_best_after_reset > 0, "Amax after reset should be positive"
assert abs(a_best_after_reset.item() - a_best_value) < 1e-6, (
"Amax after reset should match original value with same data"
)