22 changes: 18 additions & 4 deletions modelopt/torch/export/quant_utils.py
@@ -275,10 +275,16 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
     if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
-        weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
+        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:

Collaborator (review comment on the added line above):

how about in line 265, do this:

    if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
        weight_quantizer.reset_amax()
        enable_stats_collection(weight_quantizer)
        weight_quantizer(weight)
        finish_stats_collection(weight_quantizer)

So all weights have amax and you don't need the other code changes.

+            weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
+        else:
+            # Compute from weight if amax not set
+            from ..quantization.utils import reduce_amax
+
+            weight_scaling_factor_2 = reduce_amax(weight).float() / 448.0
     else:
         weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
-            weight_quantizer
+            weight_quantizer, weight
         )
     return NVFP4QTensor.get_weights_scaling_factor(
         weight,
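
Note: a quick numeric sketch of the amax/448 reasoning in the comments above
(assumptions: synthetic weights, 16-element NVFP4 blocks, the standard e2m1
max of 6.0 and e4m3 max of 448.0; this is an illustration, not modelopt code):

    import torch

    E2M1_MAX = 6.0    # largest fp4 (e2m1) magnitude
    E4M3_MAX = 448.0  # largest fp8 (e4m3) magnitude

    weight = torch.randn(128, 64)
    amax = weight.abs().amax().float()
    wsf_2 = amax / E4M3_MAX  # per-tensor factor, as chosen in the diff

    # Per-block factors: block_amax / (6 * wsf_2). The worst-case block
    # (block_amax == amax) lands exactly at 448 / 6, so weights the kernel
    # dequantizes to fp8 stay inside the e4m3 range.
    block_amax = weight.abs().reshape(-1, 16).amax(dim=-1)
    wsf = block_amax / (E2M1_MAX * wsf_2)
    assert wsf.max() <= E4M3_MAX / E2M1_MAX + 1e-6
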
@@ -304,11 +310,19 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
     ]:
-        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+        weight = getattr(module, weight_name)
+        return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer, weight)
     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
-        return weight_quantizer._amax.float() / 448.0
+        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
+            return weight_quantizer._amax.float() / 448.0
+        else:
+            # Compute from weight if amax not set
+            from ..quantization.utils import reduce_amax
+
+            weight = getattr(module, weight_name)
+            return reduce_amax(weight).float() / 448.0
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
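
Note: a minimal sketch of the reviewer's alternative, wrapped as a helper. The
import path is an assumption; reset_amax, enable_stats_collection, and
finish_stats_collection are the helpers named in the review comment above:

    # assumed location of the calibration helpers
    from modelopt.torch.quantization.utils import (
        enable_stats_collection,
        finish_stats_collection,
    )

    def ensure_weight_amax(weight_quantizer, weight):
        """Calibrate amax on `weight` if the quantizer has none yet."""
        if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
            weight_quantizer.reset_amax()
            enable_stats_collection(weight_quantizer)
            weight_quantizer(weight)  # forward pass collects max statistics
            finish_stats_collection(weight_quantizer)

Applied near the top of get_weight_scaling_factor, this would guarantee every
weight quantizer carries an amax, making the hasattr fallbacks above
unnecessary, which is the reviewer's point.
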
25 changes: 20 additions & 5 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -53,11 +53,26 @@ def get_e2m1_bounds(cls, device):
         return cls.e2m1_bounds_on_device[device]
 
     @classmethod
-    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
-        """Returns per tensor weight scaling factor from the weight_quantizer amax."""
-        # Assert that weight_quantizer has attribute amax
-        assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax"
-        return weight_quantizer._amax.float() / (6.0 * 448.0)
+    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer, weight=None):
+        """Returns per tensor weight scaling factor from the weight_quantizer amax.
+
+        Args:
+            weight_quantizer: The weight quantizer module
+            weight: Optional weight tensor to compute amax from if not set on quantizer
+        """
+        # Check if weight_quantizer has amax attribute and it's not None
+        if hasattr(weight_quantizer, "_amax") and weight_quantizer._amax is not None:
+            return weight_quantizer._amax.float() / (6.0 * 448.0)
+
+        # Fallback: compute amax from weight if provided
+        if weight is not None:
+            return cls.get_weights_scaling_factor_2(weight)
+
+        # If neither amax nor weight is available, raise an error
+        raise ValueError(
+            "Weight quantizer does not have attribute amax and no weight tensor provided. "
+            "Cannot compute scaling factor."
+        )

     @classmethod
     def get_weights_scaling_factor(
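
Note: a hedged usage sketch of the updated classmethod; _StubQuantizer is a
stand-in invented for illustration, not a modelopt class:

    import torch
    from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor

    class _StubQuantizer:
        _amax = None  # no calibrated amax on the quantizer

    weight = torch.randn(128, 64)

    # amax missing -> the new fallback computes the factor from `weight`
    wsf_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
        _StubQuantizer(), weight
    )

    # with neither amax nor weight available, the method now raises ValueError
    try:
        NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(_StubQuantizer())
    except ValueError as err:
        print(err)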