diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index eee13dc51..224164765 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -231,6 +231,25 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
     return scaling_factor
 
 
+def _ensure_weight_quantizer_calibrated(
+    weight_quantizer: TensorQuantizer, weight: torch.Tensor
+) -> None:
+    """Calibrate weight quantizer if amax is not set.
+
+    This is a lazy calibration pattern used during export when weight quantizers
+    may not have been calibrated during the main calibration phase.
+
+    Args:
+        weight_quantizer: The weight quantizer to calibrate
+        weight: The weight tensor to use for calibration
+    """
+    if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
+        weight_quantizer.reset_amax()
+        enable_stats_collection(weight_quantizer)
+        weight_quantizer(weight)
+        finish_stats_collection(weight_quantizer)
+
+
 def get_activation_scaling_factor(
     module: nn.Module, input_quantizer_name: str = "input_quantizer"
 ) -> torch.Tensor:
@@ -272,6 +291,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # Calibrate weight quantizer if amax is not set
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight)
+
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
             # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
             # This is because the kernel dequantizes weight to fp8, which is in range 448.
@@ -304,10 +326,16 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
     ]:
+        # Calibrate weight quantizer if amax is not set
+        weight = getattr(module, weight_name)
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight)
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
+        # Calibrate weight quantizer if amax is not set
+        weight = getattr(module, weight_name)
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight)
         return weight_quantizer._amax.float() / 448.0
 
     # SequentialQuantizer is required
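
For reviewers unfamiliar with the lazy-calibration pattern the new helper implements: a minimal, self-contained sketch of the same control flow, using a hypothetical `TinyQuantizer` stand-in (not the real `TensorQuantizer`) since only `reset_amax`, the stats-collection toggles, and the forward call matter here. Note the guard `not hasattr(q, "_amax") or q._amax is None` in the patch is equivalent to `getattr(q, "_amax", None) is None`, which this sketch uses for brevity.

```python
import torch


class TinyQuantizer:
    """Hypothetical stand-in for TensorQuantizer: records amax while collecting."""

    def __init__(self):
        self._amax = None
        self._collecting = False

    def reset_amax(self):
        self._amax = None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if self._collecting:
            seen = x.abs().amax()
            self._amax = seen if self._amax is None else torch.maximum(self._amax, seen)
        return x


def ensure_calibrated(q: TinyQuantizer, weight: torch.Tensor) -> None:
    # Same guard as the patch: calibrate only when no amax was ever recorded.
    if getattr(q, "_amax", None) is None:
        q.reset_amax()
        q._collecting = True   # stands in for enable_stats_collection(q)
        q(weight)              # forward pass over the weight records amax
        q._collecting = False  # stands in for finish_stats_collection(q)


q = TinyQuantizer()
ensure_calibrated(q, torch.randn(16, 16))
assert q._amax is not None  # export can now compute scaling factors
```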
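For context on the amax/448 comments repeated in both hunks: a small numeric sketch of why dividing by 448 keeps the per-block scales inside FP8 range after the kernel dequantizes to FP8. Only the constants 448 (FP8 E4M3 max) and 6 (FP4 E2M1 max) come from the patch; the tensor values are made up for illustration.

```python
import torch

E4M3_MAX = 448.0  # largest finite FP8 (E4M3) value; the kernel dequantizes into this range
E2M1_MAX = 6.0    # largest finite FP4 (E2M1) value

amax = torch.tensor(3.2)        # hypothetical per-tensor weight amax
wsf2 = amax / E4M3_MAX          # weight_scaling_factor_2, as computed in the patch

block_amax = torch.tensor(1.6)  # hypothetical per-block amax (always <= tensor amax)
per_block_scale = block_amax / E2M1_MAX
fp8_encoded = per_block_scale / wsf2  # the per-block scale the kernel stores in FP8

# The encoded scale is bounded by 448/6 (~74.7), comfortably inside FP8 range,
# which is the "wsf is in range 448/6" claim in the patch comments.
assert fp8_encoded <= E4M3_MAX / E2M1_MAX
print(f"encoded per-block scale: {fp8_encoded.item():.2f} (bound: {E4M3_MAX / E2M1_MAX:.2f})")
```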