Skip to content

Commit 2256067

Browse files
author
RJ Ascani
committed
Improve depthwise conv2d implementation with AOT optimizations and validations
Key changes: - Move depth_multiplier calculation from runtime to AOT pass (eliminates runtime division by computing depth_multiplier = output_channels / input_channels in the graph transformation pass) - Add critical defensive validations in validate_depthwise_conv2d_arguments(): * Validate IHWO weight layout (dimension 0 must be 1) * Validate dilation == 1 (CMSIS-NN constraint) * Validate depth_multiplier consistency with channel counts - Fix CMSIS-NN API usage: * Use arm_depthwise_conv_wrapper_s8_get_buffer_size() with correct parameters * Improve buffer allocation error handling with detailed error messages - Add _compute_depthwise_conv2d_output_shape() to read channels from correct dimension (dim 3 for IHWO layout vs dim 0 for OHWI) - Update operator schema to use depth_multiplier parameter instead of groups This ensures proper validation of CMSIS-NN constraints and moves computation to compile-time where possible.
1 parent 7de9f62 commit 2256067

File tree

4 files changed

+138
-54
lines changed

4 files changed

+138
-54
lines changed

backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*
2-
* Copyright 2025 Arm Limited and/or its affiliates.
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
34
*
45
* This source code is licensed under the BSD-style license found in the
56
* LICENSE file in the root directory of this source tree.
@@ -28,7 +29,7 @@ bool validate_depthwise_conv2d_arguments(
2829
const IntArrayRef& stride,
2930
const IntArrayRef& padding,
3031
const IntArrayRef& dilation,
31-
const int64_t groups,
32+
const int64_t depth_multiplier,
3233
const Tensor& requantize_multipliers,
3334
const Tensor& requantize_shifts) {
3435
if (input.dim() != kConvDim || weight.dim() != kConvDim ||
@@ -38,6 +39,28 @@ bool validate_depthwise_conv2d_arguments(
3839
return false;
3940
}
4041

42+
// Validate weight is in IHWO layout: [1, H, W, C_OUT]
43+
if (weight.size(0) != 1) {
44+
ET_LOG(
45+
Error,
46+
"quantized_depthwise_conv2d_out: weight dim 0 must be 1 for IHWO layout, got %zd",
47+
weight.size(0));
48+
context.fail(Error::InvalidArgument);
49+
return false;
50+
}
51+
52+
const int64_t weight_output_channels = weight.size(3);
53+
const int64_t output_channels = output.size(1);
54+
if (weight_output_channels != output_channels) {
55+
ET_LOG(
56+
Error,
57+
"quantized_depthwise_conv2d_out: weight output channels (%zd) must match output channels (%zd)",
58+
weight_output_channels,
59+
output_channels);
60+
context.fail(Error::InvalidArgument);
61+
return false;
62+
}
63+
4164
// Check for channels_last dim_order (NHWC: 0, 2, 3, 1)
4265
// Skip check if channels == 1, as dim_order is ambiguous in that case
4366
constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {
@@ -91,25 +114,36 @@ bool validate_depthwise_conv2d_arguments(
91114
return false;
92115
}
93116

94-
// Depthwise convolution constraint: groups must equal input channels
117+
// CMSIS-NN depthwise convolution does not support dilation != 1
118+
if (dilation[0] != 1 || dilation[1] != 1) {
119+
ET_LOG(
120+
Error,
121+
"quantized_depthwise_conv2d_out: dilation != 1 not supported, got (%zd, %zd)",
122+
dilation[0],
123+
dilation[1]);
124+
context.fail(Error::InvalidArgument);
125+
return false;
126+
}
127+
95128
const int64_t input_channels = input.size(1);
96-
if (groups != input_channels) {
129+
// output_channels already extracted above for weight validation
130+
if (output_channels != input_channels * depth_multiplier) {
97131
ET_LOG(
98132
Error,
99-
"quantized_depthwise_conv2d_out: groups (%zd) must equal input channels (%zd) for depthwise convolution",
100-
groups,
101-
input_channels);
133+
"quantized_depthwise_conv2d_out: output channels (%zd) must equal input channels (%zd) * depth_multiplier (%zd)",
134+
output_channels,
135+
input_channels,
136+
depth_multiplier);
102137
context.fail(Error::InvalidArgument);
103138
return false;
104139
}
105140

106-
const int64_t out_channels = output.size(1);
107-
if (requantize_multipliers.size(0) != out_channels ||
108-
requantize_shifts.size(0) != out_channels) {
141+
if (requantize_multipliers.size(0) != output_channels ||
142+
requantize_shifts.size(0) != output_channels) {
109143
ET_LOG(
110144
Error,
111145
"quantized_depthwise_conv2d_out: per-channel params must match output channels (%zd)",
112-
out_channels);
146+
output_channels);
113147
context.fail(Error::InvalidArgument);
114148
return false;
115149
}
@@ -126,7 +160,7 @@ Tensor& quantized_depthwise_conv2d_out(
126160
const IntArrayRef stride,
127161
const IntArrayRef padding,
128162
const IntArrayRef dilation,
129-
const int64_t groups,
163+
const int64_t depth_multiplier,
130164
const int64_t input_offset,
131165
const int64_t output_offset,
132166
const Tensor& requantize_multipliers,
@@ -143,7 +177,7 @@ Tensor& quantized_depthwise_conv2d_out(
143177
stride,
144178
padding,
145179
dilation,
146-
groups,
180+
depth_multiplier,
147181
requantize_multipliers,
148182
requantize_shifts)) {
149183
return out;
@@ -154,15 +188,17 @@ Tensor& quantized_depthwise_conv2d_out(
154188
const int32_t input_height = static_cast<int32_t>(input.size(2));
155189
const int32_t input_width = static_cast<int32_t>(input.size(3));
156190

157-
const int32_t kernel_output_channels = static_cast<int32_t>(weight.size(0));
191+
// Weight is in IHWO layout after permutation in the pass: [1, H, W, C_OUT]
192+
// For depthwise conv, this matches CMSIS-NN's expected format
158193
const int32_t kernel_height = static_cast<int32_t>(weight.size(1));
159194
const int32_t kernel_width = static_cast<int32_t>(weight.size(2));
160-
const int32_t depth_multiplier = static_cast<int32_t>(weight.size(3));
161195

162196
const int32_t output_channels = static_cast<int32_t>(out.size(1));
163197
const int32_t output_height = static_cast<int32_t>(out.size(2));
164198
const int32_t output_width = static_cast<int32_t>(out.size(3));
165199

200+
const int32_t depth_multiplier_val = static_cast<int32_t>(depth_multiplier);
201+
166202
const int32_t input_offset_val = static_cast<int32_t>(input_offset);
167203
const int32_t output_offset_val = static_cast<int32_t>(output_offset);
168204
const int32_t activation_min_val = static_cast<int32_t>(activation_min);
@@ -179,7 +215,7 @@ Tensor& quantized_depthwise_conv2d_out(
179215
cmsis_nn_dw_conv_params dw_conv_params;
180216
dw_conv_params.input_offset = input_offset_val;
181217
dw_conv_params.output_offset = output_offset_val;
182-
dw_conv_params.ch_mult = depth_multiplier;
218+
dw_conv_params.ch_mult = depth_multiplier_val;
183219
dw_conv_params.stride.h = static_cast<const int32_t>(stride[0]);
184220
dw_conv_params.stride.w = static_cast<const int32_t>(stride[1]);
185221
dw_conv_params.padding.h = static_cast<const int32_t>(padding[0]);
@@ -203,24 +239,23 @@ Tensor& quantized_depthwise_conv2d_out(
203239
cmsis_context.buf = nullptr;
204240
cmsis_context.size = 0;
205241

206-
const size_t buffer_bytes = static_cast<size_t>(
207-
arm_depthwise_conv_s8_get_buffer_size(&input_dims, &filter_dims));
242+
const size_t buffer_bytes =
243+
static_cast<size_t>(arm_depthwise_conv_wrapper_s8_get_buffer_size(
244+
&dw_conv_params, &input_dims, &filter_dims, &output_dims));
208245
if (buffer_bytes > 0) {
209246
auto buffer_or_error =
210247
context.allocate_temp(buffer_bytes, alignof(int16_t));
211248
if (!buffer_or_error.ok()) {
212-
if (buffer_or_error.error() != Error::NotFound) {
213-
ET_LOG(
214-
Error,
215-
"quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d)",
216-
static_cast<int>(buffer_or_error.error()));
217-
context.fail(buffer_or_error.error());
218-
return out;
219-
}
220-
} else {
221-
cmsis_context.buf = buffer_or_error.get();
222-
cmsis_context.size = buffer_bytes;
249+
ET_LOG(
250+
Error,
251+
"quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)",
252+
static_cast<int>(buffer_bytes),
253+
static_cast<int>(buffer_or_error.error()));
254+
context.fail(buffer_or_error.error());
255+
return out;
223256
}
257+
cmsis_context.buf = buffer_or_error.get();
258+
cmsis_context.size = buffer_bytes;
224259
}
225260

226261
const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8(

backends/cortex_m/ops/operators.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,35 @@ def _compute_conv2d_output_shape(
488488
return torch.Size([batch, out_channels, out_height, out_width])
489489

490490

491+
def _compute_depthwise_conv2d_output_shape(
492+
input_shape: torch.Size,
493+
weight_shape: torch.Size,
494+
stride: Sequence[int],
495+
padding: Sequence[int],
496+
dilation: Sequence[int],
497+
) -> torch.Size:
498+
batch = input_shape[0]
499+
in_height = input_shape[2]
500+
in_width = input_shape[3]
501+
# For depthwise conv, we store the weights in IHWO layout (1, kernel_h, kernel_w, out)
502+
# where dimension 3 contains the output channels
503+
kernel_height = weight_shape[1]
504+
kernel_width = weight_shape[2]
505+
506+
stride_h, stride_w = stride
507+
pad_h, pad_w = padding
508+
dilation_h, dilation_w = dilation
509+
510+
out_channels = weight_shape[3] # IHWO format: output channels at dimension 3
511+
out_height = (
512+
in_height + 2 * pad_h - dilation_h * (kernel_height - 1) - 1
513+
) // stride_h + 1
514+
out_width = (
515+
in_width + 2 * pad_w - dilation_w * (kernel_width - 1) - 1
516+
) // stride_w + 1
517+
return torch.Size([batch, out_channels, out_height, out_width])
518+
519+
491520
@register_fake("cortex_m::quantized_conv2d")
492521
def quantized_conv2d_meta(
493522
input: torch.Tensor,
@@ -591,7 +620,7 @@ def quantized_conv2d_impl(
591620
"int[] stride, "
592621
"int[] padding, "
593622
"int[] dilation, "
594-
"int groups, "
623+
"int depth_multiplier, "
595624
"int input_offset, "
596625
"int output_offset, "
597626
"Tensor requantize_multipliers, "
@@ -610,7 +639,7 @@ def quantized_conv2d_impl(
610639
"int[] stride, "
611640
"int[] padding, "
612641
"int[] dilation, "
613-
"int groups, "
642+
"int depth_multiplier, "
614643
"int input_offset, "
615644
"int output_offset, "
616645
"Tensor requantize_multipliers, "
@@ -630,7 +659,7 @@ def quantized_depthwise_conv2d_meta(
630659
stride: Sequence[int],
631660
padding: Sequence[int],
632661
dilation: Sequence[int],
633-
groups: int,
662+
depth_multiplier: int,
634663
input_offset: int,
635664
output_offset: int,
636665
requantize_multipliers: torch.Tensor,
@@ -641,7 +670,7 @@ def quantized_depthwise_conv2d_meta(
641670
stride_vals = list(stride)
642671
padding_vals = list(padding)
643672
dilation_vals = list(dilation)
644-
output_shape = _compute_conv2d_output_shape(
673+
output_shape = _compute_depthwise_conv2d_output_shape(
645674
input.shape, weight.shape, stride_vals, padding_vals, dilation_vals
646675
)
647676
return torch.empty(
@@ -660,7 +689,7 @@ def quantized_depthwise_conv2d_impl(
660689
stride: Sequence[int],
661690
padding: Sequence[int],
662691
dilation: Sequence[int],
663-
groups: int,
692+
depth_multiplier: int,
664693
input_offset: int,
665694
output_offset: int,
666695
requantize_multipliers: torch.Tensor,
@@ -673,26 +702,26 @@ def quantized_depthwise_conv2d_impl(
673702
"quantized_depthwise_conv2d expects 4D input and weight tensors"
674703
)
675704

676-
# Validate depthwise convolution constraint: groups == input_channels
677705
input_channels = input.shape[1]
678-
if groups != input_channels:
679-
raise RuntimeError(
680-
f"quantized_depthwise_conv2d: groups ({groups}) must equal input channels ({input_channels})"
681-
)
706+
groups = input_channels
682707

683708
# Convert to int32 for accumulation and apply offsets
684709
input_int32 = input.to(torch.int32) + int(input_offset)
685710
weight_int32 = weight.to(torch.int32)
686711

687712
if bias is None:
688713
bias_int32 = torch.zeros(
689-
weight.shape[0], dtype=torch.int32, device=input.device
714+
weight.shape[3],
715+
dtype=torch.int32,
716+
device=input.device, # C_OUT is at dim 3 in IHWO
690717
)
691718
else:
692719
bias_int32 = bias.to(torch.int32)
693720

694-
# Convert weights back to OIHW layout expected by torch.nn.functional.conv2d
695-
weight_oi_hw = weight_int32.permute(0, 3, 1, 2).contiguous()
721+
# Weight is in IHWO layout: [1, H, W, C_OUT]
722+
# Convert to OIHW layout expected by torch.nn.functional.conv2d
723+
# IHWO [1, H, W, C_OUT] -> OIHW [C_OUT, 1, H, W]
724+
weight_oi_hw = weight_int32.permute(3, 0, 1, 2).contiguous()
696725

697726
# Depthwise convolution has groups == input_channels
698727
conv_acc = F.conv2d(

backends/cortex_m/ops/operators.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
- arg_meta: null
6060
kernel_name: cortex_m::quantized_conv2d_out
6161

62-
- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!)
62+
- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!)
6363
variants: function
6464
kernels:
6565
- arg_meta: null

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,24 @@ def _get_convolution_replacement(self, node) -> int:
156156
quantized_multipliers.append(quantized_multiplier)
157157
quantized_shifts.append(quantized_shift)
158158

159-
# Permute the weight tensor to the OHWI layout expected by CMSIS-NN.
160159
weight_tensor = get_param_tensor(self.exported_program, weight)
161-
weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous(
162-
memory_format=torch.channels_last
163-
)
160+
161+
# Detect depthwise convolution:
162+
# PyTorch depthwise weight is [out_ch, 1, H, W] where dimension 1 is 1
163+
# and groups == input_channels (groups > 1)
164+
is_depthwise = weight_tensor.shape[1] == 1 and groups > 1
165+
166+
if is_depthwise:
167+
# For depthwise: OIHW -> IHWO which gives [1, H, W, C_OUT] for CMSIS-NN
168+
# PyTorch depthwise weight is [out_ch, 1, H, W], permute to [1, H, W, out_ch]
169+
weight_permuted = weight_tensor.permute(1, 2, 3, 0).contiguous(
170+
memory_format=torch.channels_last
171+
)
172+
else:
173+
# For regular conv: OIHW -> OHWI
174+
weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous(
175+
memory_format=torch.channels_last
176+
)
164177

165178
with node.graph.inserting_after(weight):
166179
weight_nhwc = create_constant_placeholder(
@@ -187,21 +200,28 @@ def _get_convolution_replacement(self, node) -> int:
187200
torch.tensor(quantized_shifts, dtype=torch.int32),
188201
)
189202

190-
# Detect depthwise convolution: groups == input_channels
191-
input_tensor = get_first_fake_tensor(x)
192-
input_channels = input_tensor.shape[1]
193-
is_depthwise = groups == input_channels
194-
195203
if is_depthwise:
196-
# Use depthwise convolution operator
204+
# Compute depth_multiplier for depthwise convolution
205+
# For depthwise: output_channels = input_channels * depth_multiplier
206+
# PyTorch depthwise weight is [C_OUT, 1, H, W]
207+
output_channels = weight_tensor.shape[0]
208+
input_channels = groups # For depthwise, groups == input_channels
209+
210+
if output_channels % input_channels != 0:
211+
raise ValueError(
212+
f"Depthwise conv: output_channels ({output_channels}) must be "
213+
f"divisible by input_channels ({input_channels})"
214+
)
215+
depth_multiplier = output_channels // input_channels
216+
197217
new_args = (
198218
x,
199219
weight_nhwc,
200220
bias,
201221
stride,
202222
padding,
203223
dilation,
204-
groups,
224+
depth_multiplier,
205225
-input_zero_point,
206226
output_zero_point,
207227
quantized_multiplier_tensor,

0 commit comments

Comments
 (0)