Skip to content

Commit 2546c39

Browse files
author
RJ Ascani
committed
Add batch size validation and test coverage for depthwise conv2d
CMSIS-NN arm_depthwise_conv_wrapper_s8 only supports batch size 1. Add validation in both AOT pass (fail during compilation) and runtime (defensive check).

Add 6 test cases covering edge cases:
- Combined stride/padding/bias
- 1x1 kernels (common in mobile networks)
- Higher depth_multiplier (4)
- Asymmetric kernels (1x3)
- Asymmetric stride/padding
- Larger kernels (5x5)

Fix depthwise_conv2d_stride test to use batch size 1.
1 parent 2256067 commit 2546c39

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ bool validate_depthwise_conv2d_arguments(
3939
return false;
4040
}
4141

42+
// CMSIS-NN depthwise convolution only supports batch size of 1
43+
if (input.size(0) != 1) {
44+
ET_LOG(
45+
Error,
46+
"quantized_depthwise_conv2d_out: CMSIS-NN only supports batch size 1, got %zd",
47+
input.size(0));
48+
context.fail(Error::InvalidArgument);
49+
return false;
50+
}
51+
4252
// Validate weight is in IHWO layout: [1, H, W, C_OUT]
4353
if (weight.size(0) != 1) {
4454
ET_LOG(

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,14 @@ def _get_convolution_replacement(self, node) -> int:
207207
output_channels = weight_tensor.shape[0]
208208
input_channels = groups # For depthwise, groups == input_channels
209209

210+
# CMSIS-NN depthwise convolution only supports batch size of 1
211+
input_tensor = get_first_fake_tensor(x)
212+
batch_size = input_tensor.shape[0]
213+
if batch_size != 1:
214+
raise ValueError(
215+
f"Depthwise conv: CMSIS-NN only supports batch size 1, got {batch_size}"
216+
)
217+
210218
if output_channels % input_channels != 0:
211219
raise ValueError(
212220
f"Depthwise conv: output_channels ({output_channels}) must be "

backends/cortex_m/test/ops/test_conv.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def forward(self, x):
238238
"depthwise_conv2d_stride": McuTestCase(
239239
model=CortexMDepthwiseConv2D(4, 4, 3, stride=2, groups=4),
240240
example_inputs=(
241-
ramp_tensor(-50, 50, (2, 4, 8, 8)).to(memory_format=torch.channels_last),
241+
ramp_tensor(-50, 50, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
242242
),
243243
),
244244
"depthwise_conv2d_padding": McuTestCase(
@@ -253,6 +253,42 @@ def forward(self, x):
253253
ramp_tensor(-10, 10, (1, 3, 6, 6)).to(memory_format=torch.channels_last),
254254
),
255255
),
256+
"depthwise_conv2d_stride_padding_bias": McuTestCase(
257+
model=CortexMDepthwiseConv2DBias(4, 4, 3, stride=2, padding=1, groups=4),
258+
example_inputs=(
259+
ramp_tensor(0, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
260+
),
261+
),
262+
"depthwise_conv2d_1x1": McuTestCase(
263+
model=CortexMDepthwiseConv2D(4, 8, 1, groups=4),
264+
example_inputs=(
265+
ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
266+
),
267+
),
268+
"depthwise_conv2d_multiplier_4": McuTestCase(
269+
model=CortexMDepthwiseConv2D(2, 8, 3, groups=2),
270+
example_inputs=(
271+
ramp_tensor(0, 10, (1, 2, 8, 8)).to(memory_format=torch.channels_last),
272+
),
273+
),
274+
"depthwise_conv2d_asymmetric_kernel": McuTestCase(
275+
model=CortexMDepthwiseConv2D(4, 4, (1, 3), groups=4),
276+
example_inputs=(
277+
ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
278+
),
279+
),
280+
"depthwise_conv2d_asymmetric_stride": McuTestCase(
281+
model=CortexMDepthwiseConv2D(3, 3, 3, stride=(2, 1), padding=(1, 0), groups=3),
282+
example_inputs=(
283+
ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
284+
),
285+
),
286+
"depthwise_conv2d_5x5": McuTestCase(
287+
model=CortexMDepthwiseConv2D(4, 4, 5, padding=2, groups=4),
288+
example_inputs=(
289+
ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
290+
),
291+
),
256292
}
257293

258294

0 commit comments

Comments
 (0)