
Commit c013786

Author: RJ Ascani (committed)
Cortex-M backend: Add depthwise conv2d operator
Add quantized depthwise convolution operator for the Cortex-M backend using CMSIS-NN's optimized arm_depthwise_conv_wrapper_s8 function.

Key changes:

- New op_quantized_depthwise_conv2d.cpp with CMSIS-NN implementation
- Python operator registration in operators.py with reference implementation
- Operator schema definition in operators.yaml
- Updated ConvertToCortexMPass to automatically detect and route depthwise convolutions (where groups == input_channels) to the specialized operator
- Comprehensive test coverage with 5 test cases covering different depthwise convolution scenarios (stride, padding, bias, depth multiplier)

The implementation validates the depthwise constraint (groups must equal input channels) and supports NHWC layout, int8 quantization, per-channel requantization, and configurable stride/padding/dilation parameters.
1 parent c493e2d commit c013786
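The routing rule described in the commit message reduces to a single structural check on the convolution's arguments. As a rough illustration (the helper name below is hypothetical, not the actual ConvertToCortexMPass code), the predicate looks like this:

```python
# Hypothetical sketch of the depthwise routing test described in the
# commit message; the real ConvertToCortexMPass internals may differ.
def is_depthwise(groups: int, input_channels: int) -> bool:
    # Depthwise constraint: one convolution group per input channel.
    return groups == input_channels

# A conv over 32 channels with groups=32 is routed to the specialized
# quantized_depthwise_conv2d operator; groups=1 stays on the regular
# quantized_conv2d path.
assert is_depthwise(groups=32, input_channels=32)
assert not is_depthwise(groups=1, input_channels=32)
```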

File tree

6 files changed: +513 additions, -15 deletions


backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_depthwise_conv2d.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
```
backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp (new file)

Lines changed: 250 additions & 0 deletions
```cpp
/*
 * Copyright 2025 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "cortex_m_ops_common.h"

extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

namespace {
constexpr int64_t kConvDim = 4;

bool validate_depthwise_conv2d_arguments(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& weight,
    const torch::executor::optional<Tensor>& bias,
    const Tensor& output,
    const IntArrayRef& stride,
    const IntArrayRef& padding,
    const IntArrayRef& dilation,
    const int64_t groups,
    const Tensor& requantize_multipliers,
    const Tensor& requantize_shifts) {
  if (input.dim() != kConvDim || weight.dim() != kConvDim ||
      output.dim() != kConvDim) {
    ET_LOG(Error, "quantized_depthwise_conv2d_out: tensors must be 4-D");
    context.fail(Error::InvalidArgument);
    return false;
  }

  // Check for channels_last dim_order (NHWC: 0, 2, 3, 1)
  // Skip check if channels == 1, as dim_order is ambiguous in that case
  constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {
      0, 2, 3, 1};
  executorch::aten::ArrayRef<executorch::aten::DimOrderType>
      channels_last_order(kChannelsLastDimOrder, 4);

  if (input.size(1) > 1 && input.dim_order() != channels_last_order) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: input must have channels_last dim_order (NHWC)");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (output.size(1) > 1 && output.dim_order() != channels_last_order) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: output must have channels_last dim_order (NHWC)");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (input.scalar_type() != ScalarType::Char ||
      output.scalar_type() != ScalarType::Char) {
    ET_LOG(
        Error, "quantized_depthwise_conv2d_out: input and output must be int8");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (weight.scalar_type() != ScalarType::Char) {
    ET_LOG(Error, "quantized_depthwise_conv2d_out: weight must be int8");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) {
    ET_LOG(
        Error, "quantized_depthwise_conv2d_out: bias must be int32 if provided");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: stride, padding, and dilation must have length 2");
    context.fail(Error::InvalidArgument);
    return false;
  }

  // Depthwise convolution constraint: groups must equal input channels
  const int64_t input_channels = input.size(1);
  if (groups != input_channels) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: groups (%ld) must equal input channels (%ld) for depthwise convolution",
        static_cast<long>(groups),
        static_cast<long>(input_channels));
    context.fail(Error::InvalidArgument);
    return false;
  }

  const int64_t out_channels = output.size(1);
  if (requantize_multipliers.size(0) != out_channels ||
      requantize_shifts.size(0) != out_channels) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: per-channel params must match output channels (%ld)",
        static_cast<long>(out_channels));
    context.fail(Error::InvalidArgument);
    return false;
  }

  return true;
}
} // namespace

Tensor& quantized_depthwise_conv2d_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& weight,
    const torch::executor::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const int64_t groups,
    const int64_t input_offset,
    const int64_t output_offset,
    const Tensor& requantize_multipliers,
    const Tensor& requantize_shifts,
    const int64_t activation_min,
    const int64_t activation_max,
    Tensor& out) {
  if (!validate_depthwise_conv2d_arguments(
          context,
          input,
          weight,
          bias,
          out,
          stride,
          padding,
          dilation,
          groups,
          requantize_multipliers,
          requantize_shifts)) {
    return out;
  }

  const int32_t batch = static_cast<int32_t>(input.size(0));
  const int32_t input_channels = static_cast<int32_t>(input.size(1));
  const int32_t input_height = static_cast<int32_t>(input.size(2));
  const int32_t input_width = static_cast<int32_t>(input.size(3));

  const int32_t kernel_output_channels = static_cast<int32_t>(weight.size(0));
  const int32_t kernel_height = static_cast<int32_t>(weight.size(1));
  const int32_t kernel_width = static_cast<int32_t>(weight.size(2));
  const int32_t depth_multiplier = static_cast<int32_t>(weight.size(3));

  const int32_t output_channels = static_cast<int32_t>(out.size(1));
  const int32_t output_height = static_cast<int32_t>(out.size(2));
  const int32_t output_width = static_cast<int32_t>(out.size(3));

  const int32_t input_offset_val = static_cast<int32_t>(input_offset);
  const int32_t output_offset_val = static_cast<int32_t>(output_offset);
  const int32_t activation_min_val = static_cast<int32_t>(activation_min);
  const int32_t activation_max_val = static_cast<int32_t>(activation_max);

  const cmsis_nn_dims input_dims{
      batch, input_height, input_width, input_channels};
  const cmsis_nn_dims filter_dims{
      1, kernel_height, kernel_width, output_channels};
  const cmsis_nn_dims output_dims{
      batch, output_height, output_width, output_channels};
  const cmsis_nn_dims bias_dims{1, 1, 1, output_channels};

  cmsis_nn_dw_conv_params dw_conv_params;
  dw_conv_params.input_offset = input_offset_val;
  dw_conv_params.output_offset = output_offset_val;
  dw_conv_params.ch_mult = depth_multiplier;
  dw_conv_params.stride.h = static_cast<int32_t>(stride[0]);
  dw_conv_params.stride.w = static_cast<int32_t>(stride[1]);
  dw_conv_params.padding.h = static_cast<int32_t>(padding[0]);
  dw_conv_params.padding.w = static_cast<int32_t>(padding[1]);
  dw_conv_params.dilation.h = static_cast<int32_t>(dilation[0]);
  dw_conv_params.dilation.w = static_cast<int32_t>(dilation[1]);
  dw_conv_params.activation.min = activation_min_val;
  dw_conv_params.activation.max = activation_max_val;

  cmsis_nn_per_channel_quant_params quant_params;
  quant_params.multiplier = requantize_multipliers.data_ptr<int32_t>();
  quant_params.shift = requantize_shifts.data_ptr<int32_t>();

  const int8_t* input_data = input.const_data_ptr<int8_t>();
  const int8_t* weight_data = weight.const_data_ptr<int8_t>();
  int8_t* output_data = out.mutable_data_ptr<int8_t>();
  const int32_t* bias_data =
      bias.has_value() ? bias.value().const_data_ptr<int32_t>() : nullptr;

  cmsis_nn_context cmsis_context;
  cmsis_context.buf = nullptr;
  cmsis_context.size = 0;

  const size_t buffer_bytes = static_cast<size_t>(
      arm_depthwise_conv_s8_get_buffer_size(&input_dims, &filter_dims));
  if (buffer_bytes > 0) {
    auto buffer_or_error =
        context.allocate_temp(buffer_bytes, alignof(int16_t));
    if (!buffer_or_error.ok()) {
      if (buffer_or_error.error() != Error::NotFound) {
        ET_LOG(
            Error,
            "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d)",
            static_cast<int>(buffer_or_error.error()));
        context.fail(buffer_or_error.error());
        return out;
      }
    } else {
      cmsis_context.buf = buffer_or_error.get();
      cmsis_context.size = buffer_bytes;
    }
  }

  const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8(
      &cmsis_context,
      &dw_conv_params,
      &quant_params,
      &input_dims,
      input_data,
      &filter_dims,
      weight_data,
      &bias_dims,
      bias_data,
      &output_dims,
      output_data);

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: arm_depthwise_conv_wrapper_s8 failed with status %d",
        status);
    context.fail(Error::Internal);
  }

  return out;
}

} // namespace native
} // namespace cortex_m
```
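The requantize_multipliers and requantize_shifts tensors consumed by this kernel follow CMSIS-NN's usual fixed-point convention: each output channel's effective scale (input_scale * weight_scale[c] / output_scale) is encoded as a Q31 multiplier plus a power-of-two shift. A minimal sketch of that encoding, assuming the standard frexp-based decomposition used by TFLite/CMSIS-NN (the helper name is illustrative; the backend's actual export code may differ):

```python
import math

def quantize_effective_scale(effective_scale: float):
    """Encode a positive float scale as (q31_multiplier, shift) such that
    effective_scale ~= multiplier * 2**shift / 2**31."""
    if effective_scale == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(effective_scale)  # mantissa in [0.5, 1)
    multiplier = round(mantissa * (1 << 31))       # Q31 fixed point
    if multiplier == (1 << 31):                    # mantissa rounded up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift

# Example: an effective per-channel scale of 0.0123 decomposes as
# roughly 0.7872 * 2**-6, giving a Q31 multiplier and shift = -6.
mult, shift = quantize_effective_scale(0.0123)
assert abs(mult * 2.0**shift / 2**31 - 0.0123) < 1e-9
```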
