From 842b6e5c9d57d0550db4c49d46ada8122b2dde28 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Wed, 14 Jan 2026 00:38:12 -0800
Subject: [PATCH 1/4] fix a nvfp4 quantization amax attribute error

Signed-off-by: Zhiyu Cheng
---
 modelopt/torch/export/quant_utils.py          | 20 ++++++++++++---
 .../quantization/qtensor/nvfp4_tensor.py      | 25 +++++++++++++++----
 2 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index eee13dc51..60b359978 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -275,10 +275,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
             # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
             # This is because the kernel dequantizes weight to fp8, which is in range 448.
-            weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
+            if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
+                weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
+            else:
+                # Compute from weight if amax not set
+                from ..utils import reduce_amax
+                weight_scaling_factor_2 = reduce_amax(weight).float() / 448.0
         else:
             weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
-                weight_quantizer
+                weight_quantizer, weight
             )
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
@@ -304,11 +309,18 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
     ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        weight = getattr(module, weight_name)
+        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer, weight)
     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
-        return weight_quantizer._amax.float() / 448.0
+        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
+            return weight_quantizer._amax.float() / 448.0
+        else:
+            # Compute from weight if amax not set
+            from ..quantization.utils import reduce_amax
+            weight = getattr(module, weight_name)
+            return reduce_amax(weight).float() / 448.0

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
index 2ff1b17e9..60ec265b9 100644
--- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
+++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -53,11 +53,26 @@ def get_e2m1_bounds(cls, device):
         return cls.e2m1_bounds_on_device[device]

     @classmethod
-    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
-        """Returns per tensor weight scaling factor from the weight_quantizer amax."""
-        # Assert that weight_quantizer has attribute amax
-        assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax"
-        return weight_quantizer._amax.float() / (6.0 * 448.0)
+    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer, weight=None):
+        """Returns per tensor weight scaling factor from the weight_quantizer amax.
+
+        Args:
+            weight_quantizer: The weight quantizer module
+            weight: Optional weight tensor to compute amax from if not set on quantizer
+        """
+        # Check if weight_quantizer has amax attribute and it's not None
+        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
+            return weight_quantizer._amax.float() / (6.0 * 448.0)
+
+        # Fallback: compute amax from weight if provided
+        if weight is not None:
+            return cls.get_weights_scaling_factor_2(weight)
+
+        # If neither amax nor weight is available, raise an error
+        raise ValueError(
+            "Weight quantizer does not have attribute amax and no weight tensor provided. "
+            "Cannot compute scaling factor."
+        )

     @classmethod
     def get_weights_scaling_factor(

From ab4fc491e3532a93e00726a13f59fcb7da9fa385 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Wed, 14 Jan 2026 00:38:41 -0800
Subject: [PATCH 2/4] fix a nvfp4 quantization amax attribute error

Signed-off-by: Zhiyu Cheng
---
 modelopt/torch/export/quant_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 60b359978..af940ba93 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -280,6 +280,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
             else:
                 # Compute from weight if amax not set
                 from ..utils import reduce_amax
+
                 weight_scaling_factor_2 = reduce_amax(weight).float() / 448.0
         else:
             weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
@@ -319,6 +320,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         else:
             # Compute from weight if amax not set
             from ..quantization.utils import reduce_amax
+
             weight = getattr(module, weight_name)
             return reduce_amax(weight).float() / 448.0


From 066667898cd6d8143a9b55f3c7589b0669e55589 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Wed, 14 Jan 2026 00:42:55 -0800
Subject: [PATCH 3/4] fix a nvfp4 quantization amax attribute error

Signed-off-by: Zhiyu Cheng
---
 modelopt/torch/export/quant_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index af940ba93..526c1f21f 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -279,7 +279,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
                 weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
             else:
                 # Compute from weight if amax not set
-                from ..utils import reduce_amax
+                from ..quantization.utils import reduce_amax

                 weight_scaling_factor_2 = reduce_amax(weight).float() / 448.0
         else:

From 3f10f742d334ab78ac7058fae7ff9c9cfb7e2d01 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Thu, 15 Jan 2026 11:17:06 -0800
Subject: [PATCH 4/4] address reviews

Signed-off-by: Zhiyu Cheng
---
 modelopt/torch/export/quant_utils.py          | 48 ++++++++++++-------
 .../quantization/qtensor/nvfp4_tensor.py      | 25 ++--------
 2 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 526c1f21f..224164765 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -231,6 +231,25 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
     return scaling_factor


+def _ensure_weight_quantizer_calibrated(
+    weight_quantizer: TensorQuantizer, weight: torch.Tensor
+) -> None:
+    """Calibrate weight quantizer if amax is not set.
+
+    This is a lazy calibration pattern used during export when weight quantizers
+    may not have been calibrated during the main calibration phase.
+
+    Args:
+        weight_quantizer: The weight quantizer to calibrate
+        weight: The weight tensor to use for calibration
+    """
+    if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
+        weight_quantizer.reset_amax()
+        enable_stats_collection(weight_quantizer)
+        weight_quantizer(weight)
+        finish_stats_collection(weight_quantizer)
+
+
 def get_activation_scaling_factor(
     module: nn.Module, input_quantizer_name: str = "input_quantizer"
 ) -> torch.Tensor:
@@ -272,19 +291,16 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # Calibrate weight quantizer if amax is not set
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight)
+
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
             # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
             # This is because the kernel dequantizes weight to fp8, which is in range 448.
-            if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
-                weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
-            else:
-                # Compute from weight if amax not set
-                from ..quantization.utils import reduce_amax
-
-                weight_scaling_factor_2 = reduce_amax(weight).float() / 448.0
+            weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
         else:
             weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
-                weight_quantizer, weight
+                weight_quantizer
             )
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
@@ -310,19 +326,17 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
     ]:
+        # Calibrate weight quantizer if amax is not set
         weight = getattr(module, weight_name)
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer, weight)
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight)
+        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
-        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
-            return weight_quantizer._amax.float() / 448.0
-        else:
-            # Compute from weight if amax not set
-            from ..quantization.utils import reduce_amax
-
-            weight = getattr(module, weight_name)
-            return reduce_amax(weight).float() / 448.0
+        # Calibrate weight quantizer if amax is not set
+        weight = getattr(module, weight_name)
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight)
+        return weight_quantizer._amax.float() / 448.0

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
index 60ec265b9..2ff1b17e9 100644
--- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
+++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -53,26 +53,11 @@ def get_e2m1_bounds(cls, device):
         return cls.e2m1_bounds_on_device[device]

     @classmethod
-    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer, weight=None):
-        """Returns per tensor weight scaling factor from the weight_quantizer amax.
-
-        Args:
-            weight_quantizer: The weight quantizer module
-            weight: Optional weight tensor to compute amax from if not set on quantizer
-        """
-        # Check if weight_quantizer has amax attribute and it's not None
-        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
-            return weight_quantizer._amax.float() / (6.0 * 448.0)
-
-        # Fallback: compute amax from weight if provided
-        if weight is not None:
-            return cls.get_weights_scaling_factor_2(weight)
-
-        # If neither amax nor weight is available, raise an error
-        raise ValueError(
-            "Weight quantizer does not have attribute amax and no weight tensor provided. "
-            "Cannot compute scaling factor."
-        )
+    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
+        """Returns per tensor weight scaling factor from the weight_quantizer amax."""
+        # Assert that weight_quantizer has attribute amax
+        assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax"
+        return weight_quantizer._amax.float() / (6.0 * 448.0)

     @classmethod
     def get_weights_scaling_factor(
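
Note: PATCH 4/4 replaces the per-call-site reduce_amax fallbacks of the earlier patches with a single lazy-calibration helper, so an uncalibrated weight quantizer is populated once (reset_amax, enable_stats_collection, a forward pass over the weight, finish_stats_collection) and every downstream consumer keeps reading _amax from the quantizer. The sketch below only illustrates the guard semantics; MockQuantizer, calibrate(), and ensure_calibrated() are made-up names for illustration, not ModelOpt code.

import torch


class MockQuantizer:
    # Hypothetical stand-in for a TensorQuantizer; not part of ModelOpt.
    def __init__(self, amax=None):
        self._amax = amax
        self.calibrated = False

    def calibrate(self, weight: torch.Tensor) -> None:
        # Stand-in for reset_amax + enable_stats_collection + forward + finish_stats_collection.
        self._amax = weight.abs().amax()
        self.calibrated = True


def ensure_calibrated(quantizer, weight: torch.Tensor) -> None:
    # Same guard as _ensure_weight_quantizer_calibrated: only calibrate when amax is missing.
    if not hasattr(quantizer, "_amax") or quantizer._amax is None:
        quantizer.calibrate(weight)


w = torch.randn(64, 64)
fresh, preset = MockQuantizer(), MockQuantizer(amax=torch.tensor(3.0))
ensure_calibrated(fresh, w)   # populates _amax from the weight
ensure_calibrated(preset, w)  # no-op: an existing amax is never overwritten
assert fresh.calibrated and not preset.calibrated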
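On the scale arithmetic used throughout the series: 6.0 is the largest magnitude representable in FP4 (E2M1) and 448.0 is the largest magnitude representable in FP8 (E4M3), so the NVFP4 per-tensor scale is amax / (6.0 * 448.0), while the W4A8 NVFP4-FP8 path uses amax / 448.0 so that the per-block scales stay within 448/6, as the in-line comments explain. The snippet below is a minimal, self-contained sketch of that arithmetic on a plain torch.Tensor; it is not the ModelOpt implementation, and the function names are made up for illustration.

import torch

E2M1_MAX = 6.0    # largest magnitude representable in FP4 (E2M1)
E4M3_MAX = 448.0  # largest magnitude representable in FP8 (E4M3)


def nvfp4_per_tensor_scale(weight: torch.Tensor) -> torch.Tensor:
    # Mirrors the quantizer-amax path: amax / (6.0 * 448.0).
    amax = weight.abs().amax()
    return amax.float() / (E2M1_MAX * E4M3_MAX)


def w4a8_nvfp4_fp8_per_tensor_scale(weight: torch.Tensor) -> torch.Tensor:
    # Mirrors the W4A8 path: amax / 448.0, keeping per-block scales within 448/6.
    amax = weight.abs().amax()
    return amax.float() / E4M3_MAX


if __name__ == "__main__":
    w = torch.randn(128, 256)
    print(nvfp4_per_tensor_scale(w), w4a8_nvfp4_fp8_per_tensor_scale(w))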