Skip to content

Commit d98d53c

Browse files
committed
FITCompress for all PyTorch models, BatchNorm1d for PyTorch models, ebops with pruning mask and quantization effect for non-HGQ use as metric
1 parent c42efb3 commit d98d53c

7 files changed

Lines changed: 356 additions & 214 deletions

File tree

src/pquant/core/keras/layers.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,10 @@ def ebops(self, include_mask=False):
363363
if include_mask:
364364
mask = self.handle_transpose(self.pruning_layer.get_hard_mask(), self.weight_transpose_back, do_transpose=True)
365365
bw_ker = bw_ker * mask
366+
_, _, f = self.get_weight_quantization_bits()
367+
quantization_step_size = 2 ** (-f - 1)
368+
step_size_mask = ops.cast((ops.abs(self._kernel) > quantization_step_size), self._kernel.dtype)
369+
bw_ker = bw_ker * step_size_mask
366370
if self.parallelization_factor < 0:
367371
ebops = ops.sum(
368372
ops.depthwise_conv(
@@ -535,6 +539,10 @@ def ebops(self, include_mask=False):
535539
if include_mask:
536540
mask = self.handle_transpose(self.pruning_layer.get_hard_mask(), self.weight_transpose_back, do_transpose=True)
537541
bw_ker = bw_ker * mask
542+
_, _, f = self.get_weight_quantization_bits()
543+
quantization_step_size = 2 ** (-f - 1)
544+
step_size_mask = ops.cast((ops.abs(self._kernel) > quantization_step_size), self._kernel.dtype)
545+
bw_ker = bw_ker * step_size_mask
538546
if self.parallelization_factor < 0:
539547
ebops = ops.sum(
540548
ops.conv(
@@ -769,6 +777,10 @@ def ebops(self, include_mask=False):
769777
if include_mask:
770778
mask = self.handle_transpose(self.pruning_layer.get_hard_mask(), self.weight_transpose_back, do_transpose=True)
771779
bw_ker = bw_ker * mask
780+
_, _, f = self.get_weight_quantization_bits()
781+
quantization_step_size = 2 ** (-f - 1)
782+
step_size_mask = ops.cast((ops.abs(self._kernel) > quantization_step_size), self._kernel.dtype)
783+
bw_ker = bw_ker * step_size_mask
772784
if self.parallelization_factor < 0:
773785
ebops = ops.sum(
774786
ops.conv(
@@ -913,6 +925,10 @@ def ebops(self, include_mask=False):
913925
if include_mask:
914926
mask = self.handle_transpose(self.pruning_layer.get_hard_mask(), self.weight_transpose_back, do_transpose=True)
915927
bw_ker = bw_ker * mask
928+
_, _, f = self.get_weight_quantization_bits()
929+
quantization_step_size = 2 ** (-f - 1)
930+
step_size_mask = ops.cast((ops.abs(self._kernel) > quantization_step_size), self._kernel.dtype)
931+
bw_ker = bw_ker * step_size_mask
916932
ebops = ops.sum(ops.matmul(bw_inp, bw_ker))
917933
ebops = ebops * self.n_parallel / self.parallelization_factor
918934
if self.use_bias:
@@ -2163,7 +2179,7 @@ def get_ebops(model):
21632179
ebops = 0
21642180
for m in model.layers:
21652181
if isinstance(m, (PQWeightBiasBase)):
2166-
ebops += m.ebops(include_mask=True)
2182+
ebops += m.ebops(include_mask=m.enable_pruning)
21672183
elif isinstance(m, (PQAvgPoolBase, PQBatchNormalization, PQActivation)):
21682184
ebops += m.ebops()
21692185
return ebops

src/pquant/core/torch/activations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def check_is_built(self, input_shape):
7979
if self.built:
8080
return
8181
self.built = True
82-
self.input_shape = input_shape
82+
self.input_shape = (1,) + input_shape[1:]
8383
self.output_quantizer = Quantizer(
8484
k=self.k_output,
8585
i=self.i_output,

0 commit comments

Comments (0)