/*
 * Copyright 2025 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "cortex_m_ops_common.h"

extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
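
// Out-variant of the Cortex-M quantized depthwise conv2d operator, lowered to
// CMSIS-NN's int8 depthwise convolution kernels. A sketch of the calling
// contract, as implied by the validation below (not an authoritative spec):
//   - input/output: int8, 4-D, channels_last (NHWC) dim_order
//   - weight: int8; bias: optional int32
//   - requantize_multipliers / requantize_shifts: one int32 entry per output
//     channel, in CMSIS-NN's fixed-point requantization encoding
// arm_nnfunctions.h is a C header, hence the extern "C" wrapper above.
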
namespace {
constexpr int64_t kConvDim = 4;

bool validate_depthwise_conv2d_arguments(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& weight,
    const torch::executor::optional<Tensor>& bias,
    const Tensor& output,
    const IntArrayRef& stride,
    const IntArrayRef& padding,
    const IntArrayRef& dilation,
    const int64_t groups,
    const Tensor& requantize_multipliers,
    const Tensor& requantize_shifts) {
  if (input.dim() != kConvDim || weight.dim() != kConvDim ||
      output.dim() != kConvDim) {
    ET_LOG(Error, "quantized_depthwise_conv2d_out: tensors must be 4-D");
    context.fail(Error::InvalidArgument);
    return false;
  }

  // Check for channels_last dim_order (NHWC: 0, 2, 3, 1)
  // Skip check if channels == 1, as dim_order is ambiguous in that case
  constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {
      0, 2, 3, 1};
  executorch::aten::ArrayRef<executorch::aten::DimOrderType>
      channels_last_order(kChannelsLastDimOrder, 4);

  if (input.size(1) > 1 && input.dim_order() != channels_last_order) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: input must have channels_last dim_order (NHWC)");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (output.size(1) > 1 && output.dim_order() != channels_last_order) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: output must have channels_last dim_order (NHWC)");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (input.scalar_type() != ScalarType::Char ||
      output.scalar_type() != ScalarType::Char) {
    ET_LOG(
        Error, "quantized_depthwise_conv2d_out: input and output must be int8");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (weight.scalar_type() != ScalarType::Char) {
    ET_LOG(Error, "quantized_depthwise_conv2d_out: weight must be int8");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) {
    ET_LOG(
        Error, "quantized_depthwise_conv2d_out: bias must be int32 if provided");
    context.fail(Error::InvalidArgument);
    return false;
  }

  if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: stride, padding, and dilation must have length 2");
    context.fail(Error::InvalidArgument);
    return false;
  }

  // Depthwise convolution constraint: groups must equal input channels
  const int64_t input_channels = input.size(1);
  if (groups != input_channels) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: groups (%zd) must equal input channels (%zd) for depthwise convolution",
        static_cast<ssize_t>(groups),
        static_cast<ssize_t>(input_channels));
    context.fail(Error::InvalidArgument);
    return false;
  }
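
  // For depthwise convolution, out_channels = in_channels * depth_multiplier
  // (e.g. 8 input channels with a channel multiplier of 2 give 16 output
  // channels), so the requantization parameters are sized per output channel.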
  const int64_t out_channels = output.size(1);
  if (requantize_multipliers.size(0) != out_channels ||
      requantize_shifts.size(0) != out_channels) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: per-channel params must match output channels (%zd)",
        static_cast<ssize_t>(out_channels));
    context.fail(Error::InvalidArgument);
    return false;
  }

  return true;
}
} // namespace

Tensor& quantized_depthwise_conv2d_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& weight,
    const torch::executor::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const int64_t groups,
    const int64_t input_offset,
    const int64_t output_offset,
    const Tensor& requantize_multipliers,
    const Tensor& requantize_shifts,
    const int64_t activation_min,
    const int64_t activation_max,
    Tensor& out) {
  if (!validate_depthwise_conv2d_arguments(
          context,
          input,
          weight,
          bias,
          out,
          stride,
          padding,
          dilation,
          groups,
          requantize_multipliers,
          requantize_shifts)) {
    return out;
  }

  const int32_t batch = static_cast<int32_t>(input.size(0));
  const int32_t input_channels = static_cast<int32_t>(input.size(1));
  const int32_t input_height = static_cast<int32_t>(input.size(2));
  const int32_t input_width = static_cast<int32_t>(input.size(3));

  // Weight arrives in channels_last order (out_ch, kernel_h, kernel_w,
  // in_ch / groups); its last dim serves as the CMSIS-NN channel multiplier.
  // weight.size(0) is not read here: CMSIS-NN takes the output channel count
  // from the output tensor dims instead.
  const int32_t kernel_height = static_cast<int32_t>(weight.size(1));
  const int32_t kernel_width = static_cast<int32_t>(weight.size(2));
  const int32_t depth_multiplier = static_cast<int32_t>(weight.size(3));

  const int32_t output_channels = static_cast<int32_t>(out.size(1));
  const int32_t output_height = static_cast<int32_t>(out.size(2));
  const int32_t output_width = static_cast<int32_t>(out.size(3));

  const int32_t input_offset_val = static_cast<int32_t>(input_offset);
  const int32_t output_offset_val = static_cast<int32_t>(output_offset);
  const int32_t activation_min_val = static_cast<int32_t>(activation_min);
  const int32_t activation_max_val = static_cast<int32_t>(activation_max);

  const cmsis_nn_dims input_dims{
      batch, input_height, input_width, input_channels};
  const cmsis_nn_dims filter_dims{
      1, kernel_height, kernel_width, output_channels};
  const cmsis_nn_dims output_dims{
      batch, output_height, output_width, output_channels};
  const cmsis_nn_dims bias_dims{1, 1, 1, output_channels};
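
  // cmsis_nn_dims fields are ordered {n, h, w, c}. The filter's n-dim is
  // fixed at 1 and its c-dim carries the output channel count, matching
  // CMSIS-NN's depthwise filter layout (1, kernel_h, kernel_w, out_ch). The
  // output spatial extents are assumed to already satisfy the standard
  // convolution relation
  //   out_h = (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1
  // (likewise for width); they are taken from the output tensor rather than
  // re-derived here.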

  cmsis_nn_dw_conv_params dw_conv_params;
  dw_conv_params.input_offset = input_offset_val;
  dw_conv_params.output_offset = output_offset_val;
  dw_conv_params.ch_mult = depth_multiplier;
  dw_conv_params.stride.h = static_cast<int32_t>(stride[0]);
  dw_conv_params.stride.w = static_cast<int32_t>(stride[1]);
  dw_conv_params.padding.h = static_cast<int32_t>(padding[0]);
  dw_conv_params.padding.w = static_cast<int32_t>(padding[1]);
  dw_conv_params.dilation.h = static_cast<int32_t>(dilation[0]);
  dw_conv_params.dilation.w = static_cast<int32_t>(dilation[1]);
  dw_conv_params.activation.min = activation_min_val;
  dw_conv_params.activation.max = activation_max_val;

  cmsis_nn_per_channel_quant_params quant_params;
  quant_params.multiplier = requantize_multipliers.data_ptr<int32_t>();
  quant_params.shift = requantize_shifts.data_ptr<int32_t>();
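
  // CMSIS-NN's per-channel requantization maps each int32 accumulator to an
  // int8 output roughly as (a sketch of the fixed-point scheme, not the
  // exact rounding behavior):
  //   out = clamp(round(acc * (multiplier / 2^31) * 2^shift) + output_offset,
  //               activation_min, activation_max)
  // so ahead-of-time lowering is expected to encode each channel's effective
  // scale (input_scale * weight_scale / output_scale) as a Q31 multiplier
  // plus a power-of-two shift.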

  const int8_t* input_data = input.const_data_ptr<int8_t>();
  const int8_t* weight_data = weight.const_data_ptr<int8_t>();
  int8_t* output_data = out.mutable_data_ptr<int8_t>();
  const int32_t* bias_data =
      bias.has_value() ? bias.value().const_data_ptr<int32_t>() : nullptr;

  cmsis_nn_context cmsis_context;
  cmsis_context.buf = nullptr;
  cmsis_context.size = 0;

  // CMSIS-NN may need a scratch buffer for its optimized depthwise paths;
  // query the size and try to obtain it from the runtime's temp allocator.
  const size_t buffer_bytes = static_cast<size_t>(
      arm_depthwise_conv_s8_get_buffer_size(&input_dims, &filter_dims));
  if (buffer_bytes > 0) {
    auto buffer_or_error =
        context.allocate_temp(buffer_bytes, alignof(int16_t));
    if (!buffer_or_error.ok()) {
      // Error::NotFound means no temp allocator is configured; in that case
      // proceed without a scratch buffer. Any other error is fatal.
      if (buffer_or_error.error() != Error::NotFound) {
        ET_LOG(
            Error,
            "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d)",
            static_cast<int>(buffer_or_error.error()));
        context.fail(buffer_or_error.error());
        return out;
      }
    } else {
      cmsis_context.buf = buffer_or_error.get();
      cmsis_context.size = buffer_bytes;
    }
  }

  // arm_depthwise_conv_wrapper_s8 dispatches to the best available CMSIS-NN
  // depthwise kernel for the given parameters and target.
  const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8(
      &cmsis_context,
      &dw_conv_params,
      &quant_params,
      &input_dims,
      input_data,
      &filter_dims,
      weight_data,
      &bias_dims,
      bias_data,
      &output_dims,
      output_data);

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "quantized_depthwise_conv2d_out: arm_depthwise_conv_wrapper_s8 failed with status %d",
        static_cast<int>(status));
    context.fail(Error::Internal);
  }

  return out;
}
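
// Shape example (hypothetical values): for a standard 3x3 depthwise conv with
// input sized (1, 8, 16, 16) in channels_last dim_order, weight sized
// (8, 3, 3, 1), groups = 8, stride = {1, 1}, padding = {1, 1}, and
// dilation = {1, 1}, the kernel fills an output sized (1, 8, 16, 16), with
// requantize_multipliers and requantize_shifts each holding 8 int32 values.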

} // namespace native
} // namespace cortex_m
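
// A minimal sketch of how an out-variant kernel like this is typically made
// visible to the ExecuTorch runtime via the kernel_util helpers. The operator
// name string and the use of EXECUTORCH_LIBRARY here are assumptions for
// illustration; the actual registration for this backend is defined elsewhere
// (e.g. in its operator YAML / build glue):
//
//   #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
//
//   EXECUTORCH_LIBRARY(
//       cortex_m,
//       "quantized_depthwise_conv2d.out",
//       cortex_m::native::quantized_depthwise_conv2d_out);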