diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 9bdb3871a2..826bb80675 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1514,7 +1514,7 @@ def choose_qparams_affine_with_min_max( @register_custom_op def _choose_qparams_affine( - input: Optional[torch.Tensor], + input: torch.Tensor, mapping_type: str, block_size: List[int], target_dtype: torch.dtype, @@ -1667,7 +1667,7 @@ def reshape_w(w): def _choose_qparams_gguf( - input: Optional[torch.Tensor], + input: torch.Tensor, block_size: List[int], target_dtype: torch.dtype, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: