@@ -363,6 +363,10 @@ def ebops(self, include_mask=False):
363363 if include_mask :
364364 mask = self .handle_transpose (self .pruning_layer .get_hard_mask (), self .weight_transpose_back , do_transpose = True )
365365 bw_ker = bw_ker * mask
366+ _ , _ , f = self .get_weight_quantization_bits ()
367+ quantization_step_size = 2 ** (- f - 1 )
368+ step_size_mask = ops .cast ((ops .abs (self ._kernel ) > quantization_step_size ), self ._kernel .dtype )
369+ bw_ker = bw_ker * step_size_mask
366370 if self .parallelization_factor < 0 :
367371 ebops = ops .sum (
368372 ops .depthwise_conv (
@@ -535,6 +539,10 @@ def ebops(self, include_mask=False):
535539 if include_mask :
536540 mask = self .handle_transpose (self .pruning_layer .get_hard_mask (), self .weight_transpose_back , do_transpose = True )
537541 bw_ker = bw_ker * mask
542+ _ , _ , f = self .get_weight_quantization_bits ()
543+ quantization_step_size = 2 ** (- f - 1 )
544+ step_size_mask = ops .cast ((ops .abs (self ._kernel ) > quantization_step_size ), self ._kernel .dtype )
545+ bw_ker = bw_ker * step_size_mask
538546 if self .parallelization_factor < 0 :
539547 ebops = ops .sum (
540548 ops .conv (
@@ -769,6 +777,10 @@ def ebops(self, include_mask=False):
769777 if include_mask :
770778 mask = self .handle_transpose (self .pruning_layer .get_hard_mask (), self .weight_transpose_back , do_transpose = True )
771779 bw_ker = bw_ker * mask
780+ _ , _ , f = self .get_weight_quantization_bits ()
781+ quantization_step_size = 2 ** (- f - 1 )
782+ step_size_mask = ops .cast ((ops .abs (self ._kernel ) > quantization_step_size ), self ._kernel .dtype )
783+ bw_ker = bw_ker * step_size_mask
772784 if self .parallelization_factor < 0 :
773785 ebops = ops .sum (
774786 ops .conv (
@@ -913,6 +925,10 @@ def ebops(self, include_mask=False):
913925 if include_mask :
914926 mask = self .handle_transpose (self .pruning_layer .get_hard_mask (), self .weight_transpose_back , do_transpose = True )
915927 bw_ker = bw_ker * mask
928+ _ , _ , f = self .get_weight_quantization_bits ()
929+ quantization_step_size = 2 ** (- f - 1 )
930+ step_size_mask = ops .cast ((ops .abs (self ._kernel ) > quantization_step_size ), self ._kernel .dtype )
931+ bw_ker = bw_ker * step_size_mask
916932 ebops = ops .sum (ops .matmul (bw_inp , bw_ker ))
917933 ebops = ebops * self .n_parallel / self .parallelization_factor
918934 if self .use_bias :
@@ -2163,7 +2179,7 @@ def get_ebops(model):
21632179 ebops = 0
21642180 for m in model .layers :
21652181 if isinstance (m , (PQWeightBiasBase )):
2166- ebops += m .ebops (include_mask = True )
2182+ ebops += m .ebops (include_mask = m . enable_pruning )
21672183 elif isinstance (m , (PQAvgPoolBase , PQBatchNormalization , PQActivation )):
21682184 ebops += m .ebops ()
21692185 return ebops
0 commit comments