1616 ReLU ,
1717 SeparableConv2D ,
1818)
19+ from keras .src .layers .input_spec import InputSpec
1920from keras .src .ops .operation_utils import compute_pooling_output_shape
2021
2122from pquant .core .keras .activations import PQActivation
@@ -276,10 +277,8 @@ def __init__(
276277 self .use_bias = use_bias
277278 self .strides = strides
278279 self .dilation_rate = dilation_rate
279- # self.weight_transpose = (2, 3, 0, 1)
280- # self.weight_transpose_back = (2, 3, 1, 0)
281- self .weight_transpose = (3 , 2 , 0 , 1 )
282- self .weight_transpose_back = (2 , 3 , 1 , 0 )
280+ self .weight_transpose = (2 , 3 , 0 , 1 )
281+ self .weight_transpose_back = (2 , 3 , 0 , 1 )
283282 self .data_transpose = (0 , 3 , 1 , 2 )
284283 self .do_transpose_data = self .data_format == "channels_last"
285284 self ._weight = None
@@ -288,10 +287,12 @@ def __init__(
288287 def build (self , input_shape ):
289288 super ().build (input_shape )
290289 if self .data_format == "channels_last" :
290+ channel_axis = - 1
291291 input_channel = input_shape [- 1 ]
292292 else :
293+ channel_axis = 1
293294 input_channel = input_shape [1 ]
294-
295+ self . input_spec = InputSpec ( min_ndim = self . rank + 2 , axes = { channel_axis : input_channel })
295296 depthwise_shape = self .kernel_size + (
296297 input_channel ,
297298 self .depth_multiplier ,
0 commit comments