From ce5ac47e23289b7df4c17697713fbe6880531358 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Wed, 10 Dec 2025 15:52:52 -0800 Subject: [PATCH 1/7] Cortex-M backend: Add depthwise conv2d operator Add quantized depthwise convolution operator for the Cortex-M backend using CMSIS-NN's optimized arm_depthwise_conv_wrapper_s8 function. Key changes: - New op_quantized_depthwise_conv2d.cpp with CMSIS-NN implementation - Python operator registration in operators.py with reference implementation - Operator schema definition in operators.yaml - Updated ConvertToCortexMPass to automatically detect and route depthwise convolutions (where groups == input_channels) to the specialized operator - Comprehensive test coverage with 5 test cases covering different depthwise convolution scenarios (stride, padding, bias, depth multiplier) The implementation validates the depthwise constraint (groups must equal input channels) and supports NHWC layout, int8 quantization, per-channel requantization, and configurable stride/padding/dilation parameters. --- backends/cortex_m/CMakeLists.txt | 1 + .../ops/op_quantized_depthwise_conv2d.cpp | 250 ++++++++++++++++++ backends/cortex_m/ops/operators.py | 141 ++++++++++ backends/cortex_m/ops/operators.yaml | 6 + .../passes/convert_to_cortex_m_pass.py | 55 ++-- backends/cortex_m/test/ops/test_conv.py | 75 ++++++ 6 files changed, 513 insertions(+), 15 deletions(-) create mode 100644 backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index ac330d4b015..7d3f7d47bf5 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_depthwise_conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp new file mode 100644 index 00000000000..0622a823d99 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -0,0 +1,250 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "cortex_m_ops_common.h" + +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +namespace { +constexpr int64_t kConvDim = 4; + +bool validate_depthwise_conv2d_arguments( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const Tensor& output, + const IntArrayRef& stride, + const IntArrayRef& padding, + const IntArrayRef& dilation, + const int64_t groups, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts) { + if (input.dim() != kConvDim || weight.dim() != kConvDim || + output.dim() != kConvDim) { + ET_LOG(Error, "quantized_depthwise_conv2d_out: tensors must be 4-D"); + context.fail(Error::InvalidArgument); + return false; + } + + // Check for channels_last dim_order (NHWC: 0, 2, 3, 1) + // Skip check if channels == 1, as dim_order is ambiguous in that case + constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = { + 0, 2, 3, 1}; + executorch::aten::ArrayRef + channels_last_order(kChannelsLastDimOrder, 4); + + if (input.size(1) > 1 && input.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: input must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (output.size(1) > 1 && output.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: output must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (input.scalar_type() != ScalarType::Char || + output.scalar_type() != ScalarType::Char) { + ET_LOG( + Error, "quantized_depthwise_conv2d_out: input and output must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (weight.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "quantized_depthwise_conv2d_out: weight must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) { + ET_LOG( + Error, "quantized_depthwise_conv2d_out: bias must be int32 if provided"); + context.fail(Error::InvalidArgument); + return false; + } + + if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: stride, padding, and dilation must have length 2"); + context.fail(Error::InvalidArgument); + return false; + } + + // Depthwise convolution constraint: groups must equal input channels + const int64_t input_channels = input.size(1); + if (groups != input_channels) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: groups (%zd) must equal input channels (%zd) for depthwise convolution", + groups, + input_channels); + context.fail(Error::InvalidArgument); + return false; + } + + const int64_t out_channels = output.size(1); + if (requantize_multipliers.size(0) != out_channels || + requantize_shifts.size(0) != out_channels) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: per-channel params must match output channels (%zd)", + out_channels); + context.fail(Error::InvalidArgument); + return false; + } + + return true; +} +} // namespace + +Tensor& quantized_depthwise_conv2d_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const int64_t groups, + const 
int64_t input_offset, + const int64_t output_offset, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts, + const int64_t activation_min, + const int64_t activation_max, + Tensor& out) { + if (!validate_depthwise_conv2d_arguments( + context, + input, + weight, + bias, + out, + stride, + padding, + dilation, + groups, + requantize_multipliers, + requantize_shifts)) { + return out; + } + + const int32_t batch = static_cast(input.size(0)); + const int32_t input_channels = static_cast(input.size(1)); + const int32_t input_height = static_cast(input.size(2)); + const int32_t input_width = static_cast(input.size(3)); + + const int32_t kernel_output_channels = static_cast(weight.size(0)); + const int32_t kernel_height = static_cast(weight.size(1)); + const int32_t kernel_width = static_cast(weight.size(2)); + const int32_t depth_multiplier = static_cast(weight.size(3)); + + const int32_t output_channels = static_cast(out.size(1)); + const int32_t output_height = static_cast(out.size(2)); + const int32_t output_width = static_cast(out.size(3)); + + const int32_t input_offset_val = static_cast(input_offset); + const int32_t output_offset_val = static_cast(output_offset); + const int32_t activation_min_val = static_cast(activation_min); + const int32_t activation_max_val = static_cast(activation_max); + + const cmsis_nn_dims input_dims{ + batch, input_height, input_width, input_channels}; + const cmsis_nn_dims filter_dims{ + 1, kernel_height, kernel_width, output_channels}; + const cmsis_nn_dims output_dims{ + batch, output_height, output_width, output_channels}; + const cmsis_nn_dims bias_dims{1, 1, 1, output_channels}; + + cmsis_nn_dw_conv_params dw_conv_params; + dw_conv_params.input_offset = input_offset_val; + dw_conv_params.output_offset = output_offset_val; + dw_conv_params.ch_mult = depth_multiplier; + dw_conv_params.stride.h = static_cast(stride[0]); + dw_conv_params.stride.w = static_cast(stride[1]); + dw_conv_params.padding.h = static_cast(padding[0]); + dw_conv_params.padding.w = static_cast(padding[1]); + dw_conv_params.dilation.h = static_cast(dilation[0]); + dw_conv_params.dilation.w = static_cast(dilation[1]); + dw_conv_params.activation.min = activation_min_val; + dw_conv_params.activation.max = activation_max_val; + + cmsis_nn_per_channel_quant_params quant_params; + quant_params.multiplier = requantize_multipliers.data_ptr(); + quant_params.shift = requantize_shifts.data_ptr(); + + const int8_t* input_data = input.const_data_ptr(); + const int8_t* weight_data = weight.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + const int32_t* bias_data = + bias.has_value() ? 
bias.value().const_data_ptr() : nullptr; + + cmsis_nn_context cmsis_context; + cmsis_context.buf = nullptr; + cmsis_context.size = 0; + + const size_t buffer_bytes = static_cast( + arm_depthwise_conv_s8_get_buffer_size(&input_dims, &filter_dims)); + if (buffer_bytes > 0) { + auto buffer_or_error = + context.allocate_temp(buffer_bytes, alignof(int16_t)); + if (!buffer_or_error.ok()) { + if (buffer_or_error.error() != Error::NotFound) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d)", + static_cast(buffer_or_error.error())); + context.fail(buffer_or_error.error()); + return out; + } + } else { + cmsis_context.buf = buffer_or_error.get(); + cmsis_context.size = buffer_bytes; + } + } + + const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8( + &cmsis_context, + &dw_conv_params, + &quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: arm_depthwise_conv_wrapper_s8 failed with status %d", + status); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 291615f613a..1ad35119a8f 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -577,3 +577,144 @@ def quantized_conv2d_impl( result = torch.clamp(result, activation_min, activation_max) return result.to(torch.int8) + + +# =================================================================== +# QUANTIZED DEPTHWISE CONV2D OPERATION DEFINITION +# =================================================================== + +lib.define( + "quantized_depthwise_conv2d(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int groups, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max" + ") -> Tensor" +) + + +lib.define( + "quantized_depthwise_conv2d.out(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int groups, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max, " + "*, Tensor(a!) 
out" + ") -> Tensor(a!)" +) + + +@register_fake("cortex_m::quantized_depthwise_conv2d") +def quantized_depthwise_conv2d_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + groups: int, + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, +) -> torch.Tensor: + stride_vals = list(stride) + padding_vals = list(padding) + dilation_vals = list(dilation) + output_shape = _compute_conv2d_output_shape( + input.shape, weight.shape, stride_vals, padding_vals, dilation_vals + ) + return torch.empty( + output_shape, + dtype=torch.int8, + device=input.device, + memory_format=torch.channels_last, + ) + + +@impl(lib, "quantized_depthwise_conv2d", "CompositeExplicitAutograd") +def quantized_depthwise_conv2d_impl( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + groups: int, + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, +) -> torch.Tensor: + if input.dim() != 4 or weight.dim() != 4: + raise RuntimeError("quantized_depthwise_conv2d expects 4D input and weight tensors") + + # Validate depthwise convolution constraint: groups == input_channels + input_channels = input.shape[1] + if groups != input_channels: + raise RuntimeError( + f"quantized_depthwise_conv2d: groups ({groups}) must equal input channels ({input_channels})" + ) + + # Convert to int32 for accumulation and apply offsets + input_int32 = input.to(torch.int32) + int(input_offset) + weight_int32 = weight.to(torch.int32) + + if bias is None: + bias_int32 = torch.zeros( + weight.shape[0], dtype=torch.int32, device=input.device + ) + else: + bias_int32 = bias.to(torch.int32) + + # Convert weights back to OIHW layout expected by torch.nn.functional.conv2d + weight_oi_hw = weight_int32.permute(0, 3, 1, 2).contiguous() + + # Depthwise convolution has groups == input_channels + conv_acc = F.conv2d( + input_int32, + weight_oi_hw, + bias_int32, + stride=tuple(stride), + padding=tuple(padding), + dilation=tuple(dilation), + groups=groups, + ) + + result_channels = [] + for output_channel_i in range(conv_acc.shape[1]): + result_channel = requantize_cmsis( + conv_acc[:, output_channel_i, :, :], + int(requantize_multipliers[output_channel_i]), + int(requantize_shifts[output_channel_i]), + ) + result_channels.append(result_channel) + + result = torch.stack(result_channels, dim=1) + + result += output_offset + result = torch.clamp(result, activation_min, activation_max) + + return result.to(torch.int8) diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 0b0b2f5c715..aa67ab3457c 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -58,3 +58,9 @@ kernels: - arg_meta: null kernel_name: cortex_m::quantized_conv2d_out + +- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_depthwise_conv2d_out diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 5a142efd639..307b4c1a368 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -187,21 +187,46 @@ def _get_convolution_replacement(self, node) -> int: torch.tensor(quantized_shifts, dtype=torch.int32), ) - new_args = ( - x, - weight_nhwc, - bias, - stride, - padding, - dilation, - -input_zero_point, - output_zero_point, - quantized_multiplier_tensor, - quantized_shift_tensor, - output_qmin, - output_qmax, - ) - return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + # Detect depthwise convolution: groups == input_channels + input_tensor = get_first_fake_tensor(x) + input_channels = input_tensor.shape[1] + is_depthwise = groups == input_channels + + if is_depthwise: + # Use depthwise convolution operator + new_args = ( + x, + weight_nhwc, + bias, + stride, + padding, + dilation, + groups, + -input_zero_point, + output_zero_point, + quantized_multiplier_tensor, + quantized_shift_tensor, + output_qmin, + output_qmax, + ) + return exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default, new_args + else: + # Use regular convolution operator + new_args = ( + x, + weight_nhwc, + bias, + stride, + padding, + dilation, + -input_zero_point, + output_zero_point, + quantized_multiplier_tensor, + quantized_shift_tensor, + output_qmin, + output_qmax, + ) + return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args def call(self, graph_module: torch.fx.GraphModule) -> PassResult: modified = False diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py index 5630abbdab3..5b000432daf 100644 --- a/backends/cortex_m/test/ops/test_conv.py +++ b/backends/cortex_m/test/ops/test_conv.py @@ -112,6 +112,50 @@ def forward(self, x): return x +class CortexMDepthwiseConv2D(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_depthwise_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=False) + + def forward(self, x): + return self.conv(x) + + +class CortexMDepthwiseConv2DBias(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_depthwise_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + 
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=True) + + def forward(self, x): + return self.conv(x) + + # in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode test_cases = { "conv2d": McuTestCase( @@ -178,6 +222,37 @@ def forward(self, x): ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last), ), ), + # Depthwise convolution tests (groups == in_channels) + "depthwise_conv2d": McuTestCase( + model=CortexMDepthwiseConv2D(4, 4, 3, groups=4), + example_inputs=( + ramp_tensor(1, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_multiplier": McuTestCase( + model=CortexMDepthwiseConv2D(3, 6, 3, groups=3), + example_inputs=( + ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_stride": McuTestCase( + model=CortexMDepthwiseConv2D(4, 4, 3, stride=2, groups=4), + example_inputs=( + ramp_tensor(-50, 50, (2, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_padding": McuTestCase( + model=CortexMDepthwiseConv2D(2, 2, 5, padding=2, groups=2), + example_inputs=( + ramp_tensor(0, 1, (1, 2, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_bias": McuTestCase( + model=CortexMDepthwiseConv2DBias(3, 3, 3, padding=1, groups=3), + example_inputs=( + ramp_tensor(-10, 10, (1, 3, 6, 6)).to(memory_format=torch.channels_last), + ), + ), } From bce750c4707f76f1079b67231b8286add3fad40a Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Wed, 10 Dec 2025 16:13:30 -0800 Subject: [PATCH 2/7] Fix formatting --- backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp | 3 ++- backends/cortex_m/ops/operators.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 0622a823d99..359c79a0f46 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -77,7 +77,8 @@ bool validate_depthwise_conv2d_arguments( if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) { ET_LOG( - Error, "quantized_depthwise_conv2d_out: bias must be int32 if provided"); + Error, + "quantized_depthwise_conv2d_out: bias must be int32 if provided"); context.fail(Error::InvalidArgument); return false; } diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 1ad35119a8f..5853d2a3ea1 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -669,7 +669,9 @@ def quantized_depthwise_conv2d_impl( activation_max: int, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: - raise RuntimeError("quantized_depthwise_conv2d expects 4D input and weight tensors") + raise RuntimeError( + "quantized_depthwise_conv2d expects 4D input and weight tensors" + ) # Validate depthwise convolution constraint: groups == input_channels input_channels = input.shape[1] From b7eca03139eccc63d004ea4d8d875382e780ce40 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Thu, 11 Dec 2025 10:07:21 -0800 Subject: [PATCH 3/7] Improve depthwise conv2d implementation with AOT optimizations and validations Key changes: - Move depth_multiplier calculation from runtime to AOT pass (eliminates runtime division by computing depth_multiplier = output_channels / 
input_channels in the graph transformation pass) - Add critical defensive validations in validate_depthwise_conv2d_arguments(): * Validate IHWO weight layout (dimension 0 must be 1) * Validate dilation == 1 (CMSIS-NN constraint) * Validate depth_multiplier consistency with channel counts - Fix CMSIS-NN API usage: * Use arm_depthwise_conv_wrapper_s8_get_buffer_size() with correct parameters * Improve buffer allocation error handling with detailed error messages - Add _compute_depthwise_conv2d_output_shape() to read channels from correct dimension (dim 3 for IHWO layout vs dim 0 for OHWI) - Update operator schema to use depth_multiplier parameter instead of groups This ensures proper validation of CMSIS-NN constraints and moves computation to compile-time where possible. --- .../ops/op_quantized_depthwise_conv2d.cpp | 93 +++++++++++++------ backends/cortex_m/ops/operators.py | 55 ++++++++--- backends/cortex_m/ops/operators.yaml | 2 +- .../passes/convert_to_cortex_m_pass.py | 42 ++++++--- 4 files changed, 138 insertions(+), 54 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 359c79a0f46..5b0bbcb79d0 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -1,5 +1,6 @@ /* - * Copyright 2025 Arm Limited and/or its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -28,7 +29,7 @@ bool validate_depthwise_conv2d_arguments( const IntArrayRef& stride, const IntArrayRef& padding, const IntArrayRef& dilation, - const int64_t groups, + const int64_t depth_multiplier, const Tensor& requantize_multipliers, const Tensor& requantize_shifts) { if (input.dim() != kConvDim || weight.dim() != kConvDim || @@ -38,6 +39,28 @@ bool validate_depthwise_conv2d_arguments( return false; } + // Validate weight is in IHWO layout: [1, H, W, C_OUT] + if (weight.size(0) != 1) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: weight dim 0 must be 1 for IHWO layout, got %zd", + weight.size(0)); + context.fail(Error::InvalidArgument); + return false; + } + + const int64_t weight_output_channels = weight.size(3); + const int64_t output_channels = output.size(1); + if (weight_output_channels != output_channels) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: weight output channels (%zd) must match output channels (%zd)", + weight_output_channels, + output_channels); + context.fail(Error::InvalidArgument); + return false; + } + // Check for channels_last dim_order (NHWC: 0, 2, 3, 1) // Skip check if channels == 1, as dim_order is ambiguous in that case constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = { @@ -91,25 +114,36 @@ bool validate_depthwise_conv2d_arguments( return false; } - // Depthwise convolution constraint: groups must equal input channels + // CMSIS-NN depthwise convolution does not support dilation != 1 + if (dilation[0] != 1 || dilation[1] != 1) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: dilation != 1 not supported, got (%zd, %zd)", + dilation[0], + dilation[1]); + context.fail(Error::InvalidArgument); + return false; + } + const int64_t input_channels = input.size(1); - if (groups != input_channels) { + // output_channels already extracted above for weight validation + if (output_channels != input_channels * depth_multiplier) { 
ET_LOG( Error, - "quantized_depthwise_conv2d_out: groups (%zd) must equal input channels (%zd) for depthwise convolution", - groups, - input_channels); + "quantized_depthwise_conv2d_out: output channels (%zd) must equal input channels (%zd) * depth_multiplier (%zd)", + output_channels, + input_channels, + depth_multiplier); context.fail(Error::InvalidArgument); return false; } - const int64_t out_channels = output.size(1); - if (requantize_multipliers.size(0) != out_channels || - requantize_shifts.size(0) != out_channels) { + if (requantize_multipliers.size(0) != output_channels || + requantize_shifts.size(0) != output_channels) { ET_LOG( Error, "quantized_depthwise_conv2d_out: per-channel params must match output channels (%zd)", - out_channels); + output_channels); context.fail(Error::InvalidArgument); return false; } @@ -126,7 +160,7 @@ Tensor& quantized_depthwise_conv2d_out( const IntArrayRef stride, const IntArrayRef padding, const IntArrayRef dilation, - const int64_t groups, + const int64_t depth_multiplier, const int64_t input_offset, const int64_t output_offset, const Tensor& requantize_multipliers, @@ -143,7 +177,7 @@ Tensor& quantized_depthwise_conv2d_out( stride, padding, dilation, - groups, + depth_multiplier, requantize_multipliers, requantize_shifts)) { return out; @@ -154,15 +188,17 @@ Tensor& quantized_depthwise_conv2d_out( const int32_t input_height = static_cast(input.size(2)); const int32_t input_width = static_cast(input.size(3)); - const int32_t kernel_output_channels = static_cast(weight.size(0)); + // Weight is in IHWO layout after permutation in the pass: [1, H, W, C_OUT] + // For depthwise conv, this matches CMSIS-NN's expected format const int32_t kernel_height = static_cast(weight.size(1)); const int32_t kernel_width = static_cast(weight.size(2)); - const int32_t depth_multiplier = static_cast(weight.size(3)); const int32_t output_channels = static_cast(out.size(1)); const int32_t output_height = static_cast(out.size(2)); const int32_t output_width = static_cast(out.size(3)); + const int32_t depth_multiplier_val = static_cast(depth_multiplier); + const int32_t input_offset_val = static_cast(input_offset); const int32_t output_offset_val = static_cast(output_offset); const int32_t activation_min_val = static_cast(activation_min); @@ -179,7 +215,7 @@ Tensor& quantized_depthwise_conv2d_out( cmsis_nn_dw_conv_params dw_conv_params; dw_conv_params.input_offset = input_offset_val; dw_conv_params.output_offset = output_offset_val; - dw_conv_params.ch_mult = depth_multiplier; + dw_conv_params.ch_mult = depth_multiplier_val; dw_conv_params.stride.h = static_cast(stride[0]); dw_conv_params.stride.w = static_cast(stride[1]); dw_conv_params.padding.h = static_cast(padding[0]); @@ -203,24 +239,23 @@ Tensor& quantized_depthwise_conv2d_out( cmsis_context.buf = nullptr; cmsis_context.size = 0; - const size_t buffer_bytes = static_cast( - arm_depthwise_conv_s8_get_buffer_size(&input_dims, &filter_dims)); + const size_t buffer_bytes = + static_cast(arm_depthwise_conv_wrapper_s8_get_buffer_size( + &dw_conv_params, &input_dims, &filter_dims, &output_dims)); if (buffer_bytes > 0) { auto buffer_or_error = context.allocate_temp(buffer_bytes, alignof(int16_t)); if (!buffer_or_error.ok()) { - if (buffer_or_error.error() != Error::NotFound) { - ET_LOG( - Error, - "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d)", - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); - return out; - } - } else { - cmsis_context.buf = buffer_or_error.get(); 
- cmsis_context.size = buffer_bytes; + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", + static_cast(buffer_bytes), + static_cast(buffer_or_error.error())); + context.fail(buffer_or_error.error()); + return out; } + cmsis_context.buf = buffer_or_error.get(); + cmsis_context.size = buffer_bytes; } const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8( diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 5853d2a3ea1..221f84c792c 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -488,6 +488,35 @@ def _compute_conv2d_output_shape( return torch.Size([batch, out_channels, out_height, out_width]) +def _compute_depthwise_conv2d_output_shape( + input_shape: torch.Size, + weight_shape: torch.Size, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], +) -> torch.Size: + batch = input_shape[0] + in_height = input_shape[2] + in_width = input_shape[3] + # For depthwise conv, we store the weights in IHWO layout (1, kernel_h, kernel_w, out) + # where dimension 3 contains the output channels + kernel_height = weight_shape[1] + kernel_width = weight_shape[2] + + stride_h, stride_w = stride + pad_h, pad_w = padding + dilation_h, dilation_w = dilation + + out_channels = weight_shape[3] # IHWO format: output channels at dimension 3 + out_height = ( + in_height + 2 * pad_h - dilation_h * (kernel_height - 1) - 1 + ) // stride_h + 1 + out_width = ( + in_width + 2 * pad_w - dilation_w * (kernel_width - 1) - 1 + ) // stride_w + 1 + return torch.Size([batch, out_channels, out_height, out_width]) + + @register_fake("cortex_m::quantized_conv2d") def quantized_conv2d_meta( input: torch.Tensor, @@ -591,7 +620,7 @@ def quantized_conv2d_impl( "int[] stride, " "int[] padding, " "int[] dilation, " - "int groups, " + "int depth_multiplier, " "int input_offset, " "int output_offset, " "Tensor requantize_multipliers, " @@ -610,7 +639,7 @@ def quantized_conv2d_impl( "int[] stride, " "int[] padding, " "int[] dilation, " - "int groups, " + "int depth_multiplier, " "int input_offset, " "int output_offset, " "Tensor requantize_multipliers, " @@ -630,7 +659,7 @@ def quantized_depthwise_conv2d_meta( stride: Sequence[int], padding: Sequence[int], dilation: Sequence[int], - groups: int, + depth_multiplier: int, input_offset: int, output_offset: int, requantize_multipliers: torch.Tensor, @@ -641,7 +670,7 @@ def quantized_depthwise_conv2d_meta( stride_vals = list(stride) padding_vals = list(padding) dilation_vals = list(dilation) - output_shape = _compute_conv2d_output_shape( + output_shape = _compute_depthwise_conv2d_output_shape( input.shape, weight.shape, stride_vals, padding_vals, dilation_vals ) return torch.empty( @@ -660,7 +689,7 @@ def quantized_depthwise_conv2d_impl( stride: Sequence[int], padding: Sequence[int], dilation: Sequence[int], - groups: int, + depth_multiplier: int, input_offset: int, output_offset: int, requantize_multipliers: torch.Tensor, @@ -673,12 +702,8 @@ def quantized_depthwise_conv2d_impl( "quantized_depthwise_conv2d expects 4D input and weight tensors" ) - # Validate depthwise convolution constraint: groups == input_channels input_channels = input.shape[1] - if groups != input_channels: - raise RuntimeError( - f"quantized_depthwise_conv2d: groups ({groups}) must equal input channels ({input_channels})" - ) + groups = input_channels # Convert to int32 for accumulation and apply offsets input_int32 = input.to(torch.int32) + 
int(input_offset) @@ -686,13 +711,17 @@ def quantized_depthwise_conv2d_impl( if bias is None: bias_int32 = torch.zeros( - weight.shape[0], dtype=torch.int32, device=input.device + weight.shape[3], + dtype=torch.int32, + device=input.device, # C_OUT is at dim 3 in IHWO ) else: bias_int32 = bias.to(torch.int32) - # Convert weights back to OIHW layout expected by torch.nn.functional.conv2d - weight_oi_hw = weight_int32.permute(0, 3, 1, 2).contiguous() + # Weight is in IHWO layout: [1, H, W, C_OUT] + # Convert to OIHW layout expected by torch.nn.functional.conv2d + # IHWO [1, H, W, C_OUT] -> OIHW [C_OUT, 1, H, W] + weight_oi_hw = weight_int32.permute(3, 0, 1, 2).contiguous() # Depthwise convolution has groups == input_channels conv_acc = F.conv2d( diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index aa67ab3457c..2fe51e6a02e 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -59,7 +59,7 @@ - arg_meta: null kernel_name: cortex_m::quantized_conv2d_out -- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 307b4c1a368..94ead0de68d 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -156,11 +156,24 @@ def _get_convolution_replacement(self, node) -> int: quantized_multipliers.append(quantized_multiplier) quantized_shifts.append(quantized_shift) - # Permute the weight tensor to the OHWI layout expected by CMSIS-NN. 
weight_tensor = get_param_tensor(self.exported_program, weight) - weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous( - memory_format=torch.channels_last - ) + + # Detect depthwise convolution: + # PyTorch depthwise weight is [out_ch, 1, H, W] where dimension 1 is 1 + # and groups == input_channels (groups > 1) + is_depthwise = weight_tensor.shape[1] == 1 and groups > 1 + + if is_depthwise: + # For depthwise: OIHW -> IHWO which gives [1, H, W, C_OUT] for CMSIS-NN + # PyTorch depthwise weight is [out_ch, 1, H, W], permute to [1, H, W, out_ch] + weight_permuted = weight_tensor.permute(1, 2, 3, 0).contiguous( + memory_format=torch.channels_last + ) + else: + # For regular conv: OIHW -> OHWI + weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous( + memory_format=torch.channels_last + ) with node.graph.inserting_after(weight): weight_nhwc = create_constant_placeholder( @@ -187,13 +200,20 @@ def _get_convolution_replacement(self, node) -> int: torch.tensor(quantized_shifts, dtype=torch.int32), ) - # Detect depthwise convolution: groups == input_channels - input_tensor = get_first_fake_tensor(x) - input_channels = input_tensor.shape[1] - is_depthwise = groups == input_channels - if is_depthwise: - # Use depthwise convolution operator + # Compute depth_multiplier for depthwise convolution + # For depthwise: output_channels = input_channels * depth_multiplier + # PyTorch depthwise weight is [C_OUT, 1, H, W] + output_channels = weight_tensor.shape[0] + input_channels = groups # For depthwise, groups == input_channels + + if output_channels % input_channels != 0: + raise ValueError( + f"Depthwise conv: output_channels ({output_channels}) must be " + f"divisible by input_channels ({input_channels})" + ) + depth_multiplier = output_channels // input_channels + new_args = ( x, weight_nhwc, @@ -201,7 +221,7 @@ def _get_convolution_replacement(self, node) -> int: stride, padding, dilation, - groups, + depth_multiplier, -input_zero_point, output_zero_point, quantized_multiplier_tensor, From a5d5e0133578fb7a8bf5c74c4bd331540a9ffb90 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Fri, 12 Dec 2025 14:43:06 -0800 Subject: [PATCH 4/7] Add batch size validation and test coverage for depthwise conv2d CMSIS-NN arm_depthwise_conv_wrapper_s8 only supports batch size 1. Add validation in both AOT pass (fail during compilation) and runtime (defensive check). Add 6 test cases covering edge cases: - Combined stride/padding/bias - 1x1 kernels (common in mobile networks) - Higher depth_multiplier (4) - Asymmetric kernels (1x3) - Asymmetric stride/padding - Larger kernels (5x5) Fix depthwise_conv2d_stride test to use batch size 1. 
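For reference, a minimal sketch of the constraints this patch validates
(the shapes below are illustrative only, not taken from the test suite):
CMSIS-NN's arm_depthwise_conv_wrapper_s8 processes a single image, so the
AOT pass rejects any input whose leading dimension is not 1, and the
depth_multiplier is recovered from the channel counts exactly as the pass
computes it.

    import torch

    # Hypothetical depthwise conv: 2 input channels, depth_multiplier 4.
    conv = torch.nn.Conv2d(2, 8, kernel_size=3, groups=2, bias=False)
    x = torch.randn(1, 2, 8, 8)        # batch must be 1 for CMSIS-NN

    out_ch = conv.weight.shape[0]      # PyTorch depthwise weight: [C_OUT, 1, H, W]
    in_ch = conv.groups                # for depthwise, groups == input channels
    assert out_ch % in_ch == 0         # mirrors the AOT pass divisibility check
    depth_multiplier = out_ch // in_ch # 8 // 2 == 4, computed at compile time

    y = conv(x)
    assert y.shape[1] == in_ch * depth_multiplier  # out_ch = in_ch * depth_multiplier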
--- .../ops/op_quantized_depthwise_conv2d.cpp | 10 +++++ .../passes/convert_to_cortex_m_pass.py | 8 ++++ backends/cortex_m/test/ops/test_conv.py | 38 ++++++++++++++++++- 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 5b0bbcb79d0..7997fa0c4cc 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -39,6 +39,16 @@ bool validate_depthwise_conv2d_arguments( return false; } + // CMSIS-NN depthwise convolution only supports batch size of 1 + if (input.size(0) != 1) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: CMSIS-NN only supports batch size 1, got %zd", + input.size(0)); + context.fail(Error::InvalidArgument); + return false; + } + // Validate weight is in IHWO layout: [1, H, W, C_OUT] if (weight.size(0) != 1) { ET_LOG( diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 94ead0de68d..8a6f279e487 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -207,6 +207,14 @@ def _get_convolution_replacement(self, node) -> int: output_channels = weight_tensor.shape[0] input_channels = groups # For depthwise, groups == input_channels + # CMSIS-NN depthwise convolution only supports batch size of 1 + input_tensor = get_first_fake_tensor(x) + batch_size = input_tensor.shape[0] + if batch_size != 1: + raise ValueError( + f"Depthwise conv: CMSIS-NN only supports batch size 1, got {batch_size}" + ) + if output_channels % input_channels != 0: raise ValueError( f"Depthwise conv: output_channels ({output_channels}) must be " diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py index 5b000432daf..b43505f3b16 100644 --- a/backends/cortex_m/test/ops/test_conv.py +++ b/backends/cortex_m/test/ops/test_conv.py @@ -238,7 +238,7 @@ def forward(self, x): "depthwise_conv2d_stride": McuTestCase( model=CortexMDepthwiseConv2D(4, 4, 3, stride=2, groups=4), example_inputs=( - ramp_tensor(-50, 50, (2, 4, 8, 8)).to(memory_format=torch.channels_last), + ramp_tensor(-50, 50, (1, 4, 8, 8)).to(memory_format=torch.channels_last), ), ), "depthwise_conv2d_padding": McuTestCase( @@ -253,6 +253,42 @@ def forward(self, x): ramp_tensor(-10, 10, (1, 3, 6, 6)).to(memory_format=torch.channels_last), ), ), + "depthwise_conv2d_stride_padding_bias": McuTestCase( + model=CortexMDepthwiseConv2DBias(4, 4, 3, stride=2, padding=1, groups=4), + example_inputs=( + ramp_tensor(0, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_1x1": McuTestCase( + model=CortexMDepthwiseConv2D(4, 8, 1, groups=4), + example_inputs=( + ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_multiplier_4": McuTestCase( + model=CortexMDepthwiseConv2D(2, 8, 3, groups=2), + example_inputs=( + ramp_tensor(0, 10, (1, 2, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_asymmetric_kernel": McuTestCase( + model=CortexMDepthwiseConv2D(4, 4, (1, 3), groups=4), + example_inputs=( + ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_asymmetric_stride": McuTestCase( + model=CortexMDepthwiseConv2D(3, 3, 3, stride=(2, 1), padding=(1, 0), groups=3), + example_inputs=( + ramp_tensor(0, 10, (1, 3, 8, 
8)).to(memory_format=torch.channels_last), + ), + ), + "depthwise_conv2d_5x5": McuTestCase( + model=CortexMDepthwiseConv2D(4, 4, 5, padding=2, groups=4), + example_inputs=( + ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), } From 44341531e148885839638f119dd772bdb754e000 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Fri, 12 Dec 2025 17:17:36 -0800 Subject: [PATCH 5/7] Use is_channels_last_tensor helper function --- .../cortex_m/ops/op_quantized_depthwise_conv2d.cpp | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 7997fa0c4cc..184ccc50c26 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -71,14 +71,7 @@ bool validate_depthwise_conv2d_arguments( return false; } - // Check for channels_last dim_order (NHWC: 0, 2, 3, 1) - // Skip check if channels == 1, as dim_order is ambiguous in that case - constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = { - 0, 2, 3, 1}; - executorch::aten::ArrayRef - channels_last_order(kChannelsLastDimOrder, 4); - - if (input.size(1) > 1 && input.dim_order() != channels_last_order) { + if (!is_channels_last_tensor(input)) { ET_LOG( Error, "quantized_depthwise_conv2d_out: input must have channels_last dim_order (NHWC)"); @@ -86,7 +79,7 @@ bool validate_depthwise_conv2d_arguments( return false; } - if (output.size(1) > 1 && output.dim_order() != channels_last_order) { + if (!is_channels_last_tensor(output)) { ET_LOG( Error, "quantized_depthwise_conv2d_out: output must have channels_last dim_order (NHWC)"); From fb34949e9ac5bb90d183a2b984f4b5bd24d06bc8 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Mon, 15 Dec 2025 11:40:08 -0800 Subject: [PATCH 6/7] Remove invalid dilation check --- .../cortex_m/ops/op_quantized_depthwise_conv2d.cpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 184ccc50c26..b0fdef5ff4f 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -117,17 +117,6 @@ bool validate_depthwise_conv2d_arguments( return false; } - // CMSIS-NN depthwise convolution does not support dilation != 1 - if (dilation[0] != 1 || dilation[1] != 1) { - ET_LOG( - Error, - "quantized_depthwise_conv2d_out: dilation != 1 not supported, got (%zd, %zd)", - dilation[0], - dilation[1]); - context.fail(Error::InvalidArgument); - return false; - } - const int64_t input_channels = input.size(1); // output_channels already extracted above for weight validation if (output_channels != input_channels * depth_multiplier) { From 577364c38e31bcca575314871141f5ee3b3d0383 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Mon, 15 Dec 2025 11:44:05 -0800 Subject: [PATCH 7/7] Shorten error messages --- .../ops/op_quantized_depthwise_conv2d.cpp | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index b0fdef5ff4f..310cffe8fb9 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -43,7 +43,7 @@ bool validate_depthwise_conv2d_arguments( if (input.size(0) != 1) { ET_LOG( Error, - 
"quantized_depthwise_conv2d_out: CMSIS-NN only supports batch size 1, got %zd", + "quantized_depthwise_conv2d_out: batch size must be 1, got %zd", input.size(0)); context.fail(Error::InvalidArgument); return false; @@ -53,7 +53,7 @@ bool validate_depthwise_conv2d_arguments( if (weight.size(0) != 1) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: weight dim 0 must be 1 for IHWO layout, got %zd", + "quantized_depthwise_conv2d_out: weight dim 0 must be 1, got %zd", weight.size(0)); context.fail(Error::InvalidArgument); return false; @@ -64,7 +64,7 @@ bool validate_depthwise_conv2d_arguments( if (weight_output_channels != output_channels) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: weight output channels (%zd) must match output channels (%zd)", + "quantized_depthwise_conv2d_out: weight out_ch (%zd) != out_ch (%zd)", weight_output_channels, output_channels); context.fail(Error::InvalidArgument); @@ -73,16 +73,14 @@ bool validate_depthwise_conv2d_arguments( if (!is_channels_last_tensor(input)) { ET_LOG( - Error, - "quantized_depthwise_conv2d_out: input must have channels_last dim_order (NHWC)"); + Error, "quantized_depthwise_conv2d_out: input must be channels_last"); context.fail(Error::InvalidArgument); return false; } if (!is_channels_last_tensor(output)) { ET_LOG( - Error, - "quantized_depthwise_conv2d_out: output must have channels_last dim_order (NHWC)"); + Error, "quantized_depthwise_conv2d_out: output must be channels_last"); context.fail(Error::InvalidArgument); return false; } @@ -102,9 +100,7 @@ bool validate_depthwise_conv2d_arguments( } if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) { - ET_LOG( - Error, - "quantized_depthwise_conv2d_out: bias must be int32 if provided"); + ET_LOG(Error, "quantized_depthwise_conv2d_out: bias must be int32"); context.fail(Error::InvalidArgument); return false; } @@ -112,7 +108,7 @@ bool validate_depthwise_conv2d_arguments( if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: stride, padding, and dilation must have length 2"); + "quantized_depthwise_conv2d_out: stride/padding/dilation must have length 2"); context.fail(Error::InvalidArgument); return false; } @@ -122,7 +118,7 @@ bool validate_depthwise_conv2d_arguments( if (output_channels != input_channels * depth_multiplier) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: output channels (%zd) must equal input channels (%zd) * depth_multiplier (%zd)", + "quantized_depthwise_conv2d_out: out_ch (%zd) != in_ch (%zd) * depth_mult (%zd)", output_channels, input_channels, depth_multiplier); @@ -134,7 +130,7 @@ bool validate_depthwise_conv2d_arguments( requantize_shifts.size(0) != output_channels) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: per-channel params must match output channels (%zd)", + "quantized_depthwise_conv2d_out: per-ch params size != out_ch (%zd)", output_channels); context.fail(Error::InvalidArgument); return false;