diff --git a/src/pruna/algorithms/torchao.py b/src/pruna/algorithms/torchao.py index 453be060..cd4bed38 100644 --- a/src/pruna/algorithms/torchao.py +++ b/src/pruna/algorithms/torchao.py @@ -323,12 +323,12 @@ def import_algorithm_packages(self) -> Dict[str, Any]: The algorithm packages. """ from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_weight_only, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, quantize_, ) from torchao.quantization.quant_api import PerRow @@ -336,13 +336,13 @@ def import_algorithm_packages(self) -> Dict[str, Any]: _is_linear = importlib.import_module("torchao.quantization.quant_api")._is_linear return dict( quantize=quantize_, - int4dq=int8_dynamic_activation_int4_weight(), - int4wo=int4_weight_only(), - int8dq=int8_dynamic_activation_int8_weight(), - int8wo=int8_weight_only(), - fp8wo=float8_weight_only(), - fp8dq=float8_dynamic_activation_float8_weight(), - fp8dqrow=float8_dynamic_activation_float8_weight( + int4dq=Int8DynamicActivationInt4WeightConfig(), + int4wo=Int4WeightOnlyConfig(), + int8dq=Int8DynamicActivationInt8WeightConfig(), + int8wo=Int8WeightOnlyConfig(), + fp8wo=Float8WeightOnlyConfig(), + fp8dq=Float8DynamicActivationFloat8WeightConfig(), + fp8dqrow=Float8DynamicActivationFloat8WeightConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=PerRow(),