Skip to content

Commit 0b16701

Browse files
committed
optional dependencies, fix depthwise bug
1 parent 42df073 commit 0b16701

3 files changed

Lines changed: 16 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ classifiers = [
2323
"Topic :: Software Development :: Libraries :: Python Modules",
2424
]
2525
dynamic = [ "version" ]
26-
dependencies = [ "keras>=3", "pyyaml>=6.0.1", "quantizers>=1.1", "torch>=2.1", "pydantic>=2.0"]
26+
dependencies = [ "hgq2", "keras>=3", "optuna", "pydantic>=2", "pyyaml>=6.0.1", "quantizers>=1.1" ]
27+
optional-dependencies.all = [ "pytest>=8.4", "tensorflow>=2.17,<=2.20", "torch>=2.1" ]
28+
optional-dependencies.tensorflow = [ "tensorflow>=2.17,<=2.20" ]
29+
optional-dependencies.test = [ "pytest>=8.4" ]
30+
optional-dependencies.torch = [ "torch>=2.1" ]
2731
urls.repository = "https://github.com/nroope/PQuant"
2832

2933
[tool.setuptools]

src/pquant/core/keras/layers.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ReLU,
1717
SeparableConv2D,
1818
)
19+
from keras.src.layers.input_spec import InputSpec
1920
from keras.src.ops.operation_utils import compute_pooling_output_shape
2021

2122
from pquant.core.keras.activations import PQActivation
@@ -276,10 +277,8 @@ def __init__(
276277
self.use_bias = use_bias
277278
self.strides = strides
278279
self.dilation_rate = dilation_rate
279-
# self.weight_transpose = (2, 3, 0, 1)
280-
# self.weight_transpose_back = (2, 3, 1, 0)
281-
self.weight_transpose = (3, 2, 0, 1)
282-
self.weight_transpose_back = (2, 3, 1, 0)
280+
self.weight_transpose = (2, 3, 0, 1)
281+
self.weight_transpose_back = (2, 3, 0, 1)
283282
self.data_transpose = (0, 3, 1, 2)
284283
self.do_transpose_data = self.data_format == "channels_last"
285284
self._weight = None
@@ -288,10 +287,12 @@ def __init__(
288287
def build(self, input_shape):
289288
super().build(input_shape)
290289
if self.data_format == "channels_last":
290+
channel_axis = -1
291291
input_channel = input_shape[-1]
292292
else:
293+
channel_axis = 1
293294
input_channel = input_shape[1]
294-
295+
self.input_spec = InputSpec(min_ndim=self.rank + 2, axes={channel_axis: input_channel})
295296
depthwise_shape = self.kernel_size + (
296297
input_channel,
297298
self.depth_multiplier,

tests/test_keras_compression_layers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -735,15 +735,18 @@ def test_ap_conv1d_channels_last_transpose(config_ap, conv1d_input):
735735

736736

737737
def test_ap_depthwiseconv2d_channels_last_transpose(config_ap, conv2d_input):
738+
if keras.backend.image_data_format() == "channels_last":
739+
conv2d_input = ops.transpose(conv2d_input, (0, 3, 1, 2))
738740
keras.backend.set_image_data_format("channels_first")
739741
inp = ops.reshape(ops.linspace(0, 1, ops.size(conv2d_input)), conv2d_input.shape)
740742

741743
inputs = keras.Input(shape=inp.shape[1:])
742-
out = DepthwiseConv2D(KERNEL_SIZE, use_bias=False, padding="same")(inputs)
744+
out = DepthwiseConv2D(KERNEL_SIZE, use_bias=False, padding="same", data_format="channels_first")(inputs)
743745
model_cf = keras.Model(inputs=inputs, outputs=out, name="test_dwconv2d")
744746
model_cf = add_compression_layers(model_cf, config_ap, inp.shape)
745-
weight_cf = model_cf.layers[1]._kernel
746747

748+
weight_cf = model_cf.layers[1]._kernel
749+
model_cf.summary()
747750
post_pretrain_functions(model_cf, config_ap)
748751
model_cf(inp, training=True)
749752
model_cf(inp, training=True)

0 commit comments

Comments
 (0)