diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 9bdb3871a2..a942ee80a8 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -467,9 +467,10 @@ def _quantize_affine_no_dtype_cast( # with numel=0 which we handle by unifying the two zero_point = None - quant = torch.clamp( - _Round.apply(input * (1.0 / scale)) + zero_point, quant_min, quant_max - ) + quant: torch.Tensor = _Round.apply(input * (1.0 / scale)) + if zero_point is not None: + quant = quant + zero_point + quant = torch.clamp(quant, quant_min, quant_max) quant = quant.view(original_shape) return quant