diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt
index ac330d4b015..7d3f7d47bf5 100644
--- a/backends/cortex_m/CMakeLists.txt
+++ b/backends/cortex_m/CMakeLists.txt
@@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_depthwise_conv2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
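Editor's note on the `requantize_multipliers` / `requantize_shifts` tensors that thread through the new kernel and pass below: CMSIS-NN expresses each per-channel requantization scale as a Q31 fixed-point multiplier plus a base-2 exponent. A minimal sketch of the standard gemmlowp-style decomposition (illustrative only; the helper this backend actually uses may differ in rounding details):

```python
import math

def decompose_scale(scale: float) -> tuple[int, int]:
    """Split a positive float scale into (q31_multiplier, shift) such that
    scale ~= multiplier * 2**(shift - 31), with multiplier in [2**30, 2**31)."""
    if scale == 0.0:
        return 0, 0
    # frexp: scale = mantissa * 2**shift, with mantissa in [0.5, 1).
    mantissa, shift = math.frexp(scale)
    multiplier = round(mantissa * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed the mantissa to 1.0; renormalize
        multiplier >>= 1
        shift += 1
    return multiplier, shift

m, s = decompose_scale(0.00024)
assert abs(m * 2.0 ** (s - 31) - 0.00024) < 1e-12
```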
ET_LOG(Error, "quantized_depthwise_conv2d_out: weight must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) { + ET_LOG(Error, "quantized_depthwise_conv2d_out: bias must be int32"); + context.fail(Error::InvalidArgument); + return false; + } + + if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: stride/padding/dilation must have length 2"); + context.fail(Error::InvalidArgument); + return false; + } + + const int64_t input_channels = input.size(1); + // output_channels already extracted above for weight validation + if (output_channels != input_channels * depth_multiplier) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: out_ch (%zd) != in_ch (%zd) * depth_mult (%zd)", + output_channels, + input_channels, + depth_multiplier); + context.fail(Error::InvalidArgument); + return false; + } + + if (requantize_multipliers.size(0) != output_channels || + requantize_shifts.size(0) != output_channels) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: per-ch params size != out_ch (%zd)", + output_channels); + context.fail(Error::InvalidArgument); + return false; + } + + return true; +} +} // namespace + +Tensor& quantized_depthwise_conv2d_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const int64_t depth_multiplier, + const int64_t input_offset, + const int64_t output_offset, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts, + const int64_t activation_min, + const int64_t activation_max, + Tensor& out) { + if (!validate_depthwise_conv2d_arguments( + context, + input, + weight, + bias, + out, + stride, + padding, + dilation, + depth_multiplier, + requantize_multipliers, + requantize_shifts)) { + return out; + } + + const int32_t batch = static_cast(input.size(0)); + const int32_t input_channels = static_cast(input.size(1)); + const int32_t input_height = static_cast(input.size(2)); + const int32_t input_width = static_cast(input.size(3)); + + // Weight is in IHWO layout after permutation in the pass: [1, H, W, C_OUT] + // For depthwise conv, this matches CMSIS-NN's expected format + const int32_t kernel_height = static_cast(weight.size(1)); + const int32_t kernel_width = static_cast(weight.size(2)); + + const int32_t output_channels = static_cast(out.size(1)); + const int32_t output_height = static_cast(out.size(2)); + const int32_t output_width = static_cast(out.size(3)); + + const int32_t depth_multiplier_val = static_cast(depth_multiplier); + + const int32_t input_offset_val = static_cast(input_offset); + const int32_t output_offset_val = static_cast(output_offset); + const int32_t activation_min_val = static_cast(activation_min); + const int32_t activation_max_val = static_cast(activation_max); + + const cmsis_nn_dims input_dims{ + batch, input_height, input_width, input_channels}; + const cmsis_nn_dims filter_dims{ + 1, kernel_height, kernel_width, output_channels}; + const cmsis_nn_dims output_dims{ + batch, output_height, output_width, output_channels}; + const cmsis_nn_dims bias_dims{1, 1, 1, output_channels}; + + cmsis_nn_dw_conv_params dw_conv_params; + dw_conv_params.input_offset = input_offset_val; + dw_conv_params.output_offset = output_offset_val; + dw_conv_params.ch_mult = depth_multiplier_val; + 
dw_conv_params.stride.h = static_cast(stride[0]); + dw_conv_params.stride.w = static_cast(stride[1]); + dw_conv_params.padding.h = static_cast(padding[0]); + dw_conv_params.padding.w = static_cast(padding[1]); + dw_conv_params.dilation.h = static_cast(dilation[0]); + dw_conv_params.dilation.w = static_cast(dilation[1]); + dw_conv_params.activation.min = activation_min_val; + dw_conv_params.activation.max = activation_max_val; + + cmsis_nn_per_channel_quant_params quant_params; + quant_params.multiplier = requantize_multipliers.data_ptr(); + quant_params.shift = requantize_shifts.data_ptr(); + + const int8_t* input_data = input.const_data_ptr(); + const int8_t* weight_data = weight.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + const int32_t* bias_data = + bias.has_value() ? bias.value().const_data_ptr() : nullptr; + + cmsis_nn_context cmsis_context; + cmsis_context.buf = nullptr; + cmsis_context.size = 0; + + const size_t buffer_bytes = + static_cast(arm_depthwise_conv_wrapper_s8_get_buffer_size( + &dw_conv_params, &input_dims, &filter_dims, &output_dims)); + if (buffer_bytes > 0) { + auto buffer_or_error = + context.allocate_temp(buffer_bytes, alignof(int16_t)); + if (!buffer_or_error.ok()) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", + static_cast(buffer_bytes), + static_cast(buffer_or_error.error())); + context.fail(buffer_or_error.error()); + return out; + } + cmsis_context.buf = buffer_or_error.get(); + cmsis_context.size = buffer_bytes; + } + + const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8( + &cmsis_context, + &dw_conv_params, + &quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_depthwise_conv2d_out: arm_depthwise_conv_wrapper_s8 failed with status %d", + status); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 291615f613a..221f84c792c 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -488,6 +488,35 @@ def _compute_conv2d_output_shape( return torch.Size([batch, out_channels, out_height, out_width]) +def _compute_depthwise_conv2d_output_shape( + input_shape: torch.Size, + weight_shape: torch.Size, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], +) -> torch.Size: + batch = input_shape[0] + in_height = input_shape[2] + in_width = input_shape[3] + # For depthwise conv, we store the weights in IHWO layout (1, kernel_h, kernel_w, out) + # where dimension 3 contains the output channels + kernel_height = weight_shape[1] + kernel_width = weight_shape[2] + + stride_h, stride_w = stride + pad_h, pad_w = padding + dilation_h, dilation_w = dilation + + out_channels = weight_shape[3] # IHWO format: output channels at dimension 3 + out_height = ( + in_height + 2 * pad_h - dilation_h * (kernel_height - 1) - 1 + ) // stride_h + 1 + out_width = ( + in_width + 2 * pad_w - dilation_w * (kernel_width - 1) - 1 + ) // stride_w + 1 + return torch.Size([batch, out_channels, out_height, out_width]) + + @register_fake("cortex_m::quantized_conv2d") def quantized_conv2d_meta( input: torch.Tensor, @@ -577,3 +606,146 @@ def quantized_conv2d_impl( result = torch.clamp(result, activation_min, activation_max) return result.to(torch.int8) + 
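The kernel above hands the per-channel multipliers and shifts to CMSIS-NN, which applies them as a doubling high multiply followed by a rounding power-of-two shift; the Python reference implementation below delegates the same step to `requantize_cmsis`. A hedged sketch of that fixed-point step, mirroring the structure of CMSIS-NN's `arm_nn_requantize` (the backend's helper may differ in edge cases such as int32 saturation):

```python
def requantize(acc: int, multiplier: int, shift: int) -> int:
    """Scale an int32 accumulator by multiplier * 2**(shift - 31).
    A positive shift is applied before the multiply, a negative one after,
    as in CMSIS-NN. Illustrative sketch; no saturation is modeled."""
    left = shift if shift > 0 else 0
    right = -shift if shift < 0 else 0
    # Doubling high multiply: keep the top 32 bits of the 64-bit product,
    # with a +2**30 rounding offset (Python's >> floors, like an arithmetic shift).
    high = ((acc << left) * multiplier + (1 << 30)) >> 31
    if right == 0:
        return high
    # Rounding divide by 2**right.
    mask = (1 << right) - 1
    remainder = high & mask
    result = high >> right
    threshold = (mask >> 1) + (1 if result < 0 else 0)
    return result + (1 if remainder > threshold else 0)

# multiplier = 2**30 encodes 0.5 in Q31; with shift = -1 the overall scale is 0.25.
assert requantize(1000, 1 << 30, -1) == 250
```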
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
index 291615f613a..221f84c792c 100644
--- a/backends/cortex_m/ops/operators.py
+++ b/backends/cortex_m/ops/operators.py
@@ -488,6 +488,35 @@ def _compute_conv2d_output_shape(
     return torch.Size([batch, out_channels, out_height, out_width])
 
 
+def _compute_depthwise_conv2d_output_shape(
+    input_shape: torch.Size,
+    weight_shape: torch.Size,
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+) -> torch.Size:
+    batch = input_shape[0]
+    in_height = input_shape[2]
+    in_width = input_shape[3]
+    # For depthwise conv, we store the weights in IHWO layout (1, kernel_h, kernel_w, out)
+    # where dimension 3 contains the output channels
+    kernel_height = weight_shape[1]
+    kernel_width = weight_shape[2]
+
+    stride_h, stride_w = stride
+    pad_h, pad_w = padding
+    dilation_h, dilation_w = dilation
+
+    out_channels = weight_shape[3]  # IHWO format: output channels at dimension 3
+    out_height = (
+        in_height + 2 * pad_h - dilation_h * (kernel_height - 1) - 1
+    ) // stride_h + 1
+    out_width = (
+        in_width + 2 * pad_w - dilation_w * (kernel_width - 1) - 1
+    ) // stride_w + 1
+    return torch.Size([batch, out_channels, out_height, out_width])
+
+
 @register_fake("cortex_m::quantized_conv2d")
 def quantized_conv2d_meta(
     input: torch.Tensor,
@@ -577,3 +606,146 @@ def quantized_conv2d_impl(
     result = torch.clamp(result, activation_min, activation_max)
 
     return result.to(torch.int8)
+
+
+# ===================================================================
+# QUANTIZED DEPTHWISE CONV2D OPERATION DEFINITION
+# ===================================================================
+
+lib.define(
+    "quantized_depthwise_conv2d("
+    "Tensor input, "
+    "Tensor weight, "
+    "Tensor? bias, "
+    "int[] stride, "
+    "int[] padding, "
+    "int[] dilation, "
+    "int depth_multiplier, "
+    "int input_offset, "
+    "int output_offset, "
+    "Tensor requantize_multipliers, "
+    "Tensor requantize_shifts, "
+    "int activation_min, "
+    "int activation_max"
+    ") -> Tensor"
+)
+
+
+lib.define(
+    "quantized_depthwise_conv2d.out("
+    "Tensor input, "
+    "Tensor weight, "
+    "Tensor? bias, "
+    "int[] stride, "
+    "int[] padding, "
+    "int[] dilation, "
+    "int depth_multiplier, "
+    "int input_offset, "
+    "int output_offset, "
+    "Tensor requantize_multipliers, "
+    "Tensor requantize_shifts, "
+    "int activation_min, "
+    "int activation_max, "
+    "*, Tensor(a!) out"
+    ") -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::quantized_depthwise_conv2d")
+def quantized_depthwise_conv2d_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+    depth_multiplier: int,
+    input_offset: int,
+    output_offset: int,
+    requantize_multipliers: torch.Tensor,
+    requantize_shifts: torch.Tensor,
+    activation_min: int,
+    activation_max: int,
+) -> torch.Tensor:
+    stride_vals = list(stride)
+    padding_vals = list(padding)
+    dilation_vals = list(dilation)
+    output_shape = _compute_depthwise_conv2d_output_shape(
+        input.shape, weight.shape, stride_vals, padding_vals, dilation_vals
+    )
+    return torch.empty(
+        output_shape,
+        dtype=torch.int8,
+        device=input.device,
+        memory_format=torch.channels_last,
+    )
+
+
+@impl(lib, "quantized_depthwise_conv2d", "CompositeExplicitAutograd")
+def quantized_depthwise_conv2d_impl(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    stride: Sequence[int],
+    padding: Sequence[int],
+    dilation: Sequence[int],
+    depth_multiplier: int,
+    input_offset: int,
+    output_offset: int,
+    requantize_multipliers: torch.Tensor,
+    requantize_shifts: torch.Tensor,
+    activation_min: int,
+    activation_max: int,
+) -> torch.Tensor:
+    if input.dim() != 4 or weight.dim() != 4:
+        raise RuntimeError(
+            "quantized_depthwise_conv2d expects 4D input and weight tensors"
+        )
+
+    input_channels = input.shape[1]
+    groups = input_channels
+
+    # Convert to int32 for accumulation and apply offsets
+    input_int32 = input.to(torch.int32) + int(input_offset)
+    weight_int32 = weight.to(torch.int32)
+
+    if bias is None:
+        bias_int32 = torch.zeros(
+            weight.shape[3],  # C_OUT is at dim 3 in IHWO
+            dtype=torch.int32,
+            device=input.device,
+        )
+    else:
+        bias_int32 = bias.to(torch.int32)
+
+    # Weight is in IHWO layout: [1, H, W, C_OUT]
+    # Convert to OIHW layout expected by torch.nn.functional.conv2d
+    # IHWO [1, H, W, C_OUT] -> OIHW [C_OUT, 1, H, W]
+    weight_oi_hw = weight_int32.permute(3, 0, 1, 2).contiguous()
+
+    # Depthwise convolution has groups == input_channels
+    conv_acc = F.conv2d(
+        input_int32,
+        weight_oi_hw,
+        bias_int32,
+        stride=tuple(stride),
+        padding=tuple(padding),
+        dilation=tuple(dilation),
+        groups=groups,
+    )
+
+    result_channels = []
+    for output_channel_i in range(conv_acc.shape[1]):
+        result_channel = requantize_cmsis(
+            conv_acc[:, output_channel_i, :, :],
+            int(requantize_multipliers[output_channel_i]),
+            int(requantize_shifts[output_channel_i]),
+        )
+        result_channels.append(result_channel)
+
+    result = torch.stack(result_channels, dim=1)
+
+    result += output_offset
+    result = torch.clamp(result, activation_min, activation_max)
+
+    return result.to(torch.int8)
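A quick numeric check of the output-shape formula in `_compute_depthwise_conv2d_output_shape` above, against PyTorch's own depthwise conv (standalone, float weights):

```python
import torch
import torch.nn.functional as F

in_h = in_w = 8
k, s, p, d = 3, 2, 1, 1

# Standard dilated-conv formula used above: (8 + 2*1 - 1*(3-1) - 1) // 2 + 1 = 4
out_h = (in_h + 2 * p - d * (k - 1) - 1) // s + 1

x = torch.randn(1, 4, in_h, in_w)
w = torch.randn(4, 1, k, k)  # native PyTorch depthwise weight: [C_OUT, 1, H, W]
y = F.conv2d(x, w, stride=s, padding=p, dilation=d, groups=4)
assert y.shape == (1, 4, out_h, out_h)  # torch.Size compares equal to a tuple
```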
diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml
index 0b0b2f5c715..2fe51e6a02e 100644
--- a/backends/cortex_m/ops/operators.yaml
+++ b/backends/cortex_m/ops/operators.yaml
@@ -58,3 +58,9 @@
   kernels:
     - arg_meta: null
       kernel_name: cortex_m::quantized_conv2d_out
+
+- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::quantized_depthwise_conv2d_out
diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py
index 5a142efd639..8a6f279e487 100644
--- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py
+++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py
@@ -156,11 +156,24 @@ def _get_convolution_replacement(self, node) -> int:
             quantized_multipliers.append(quantized_multiplier)
             quantized_shifts.append(quantized_shift)
 
-        # Permute the weight tensor to the OHWI layout expected by CMSIS-NN.
         weight_tensor = get_param_tensor(self.exported_program, weight)
-        weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous(
-            memory_format=torch.channels_last
-        )
+
+        # Detect depthwise convolution:
+        # PyTorch depthwise weight is [out_ch, 1, H, W] where dimension 1 is 1
+        # and groups == input_channels (groups > 1)
+        is_depthwise = weight_tensor.shape[1] == 1 and groups > 1
+
+        if is_depthwise:
+            # For depthwise: OIHW -> IHWO which gives [1, H, W, C_OUT] for CMSIS-NN
+            # PyTorch depthwise weight is [out_ch, 1, H, W], permute to [1, H, W, out_ch]
+            weight_permuted = weight_tensor.permute(1, 2, 3, 0).contiguous(
+                memory_format=torch.channels_last
+            )
+        else:
+            # For regular conv: OIHW -> OHWI
+            weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous(
+                memory_format=torch.channels_last
+            )
 
         with node.graph.inserting_after(weight):
             weight_nhwc = create_constant_placeholder(
@@ -187,21 +200,61 @@ def _get_convolution_replacement(self, node) -> int:
             torch.tensor(quantized_shifts, dtype=torch.int32),
         )
 
-        new_args = (
-            x,
-            weight_nhwc,
-            bias,
-            stride,
-            padding,
-            dilation,
-            -input_zero_point,
-            output_zero_point,
-            quantized_multiplier_tensor,
-            quantized_shift_tensor,
-            output_qmin,
-            output_qmax,
-        )
-        return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args
+        if is_depthwise:
+            # Compute depth_multiplier for depthwise convolution
+            # For depthwise: output_channels = input_channels * depth_multiplier
+            # PyTorch depthwise weight is [C_OUT, 1, H, W]
+            output_channels = weight_tensor.shape[0]
+            input_channels = groups  # For depthwise, groups == input_channels
+
+            # CMSIS-NN depthwise convolution only supports batch size of 1
+            input_tensor = get_first_fake_tensor(x)
+            batch_size = input_tensor.shape[0]
+            if batch_size != 1:
+                raise ValueError(
+                    f"Depthwise conv: CMSIS-NN only supports batch size 1, got {batch_size}"
+                )
+
+            if output_channels % input_channels != 0:
+                raise ValueError(
+                    f"Depthwise conv: output_channels ({output_channels}) must be "
+                    f"divisible by input_channels ({input_channels})"
+                )
+            depth_multiplier = output_channels // input_channels
+
+            new_args = (
+                x,
+                weight_nhwc,
+                bias,
+                stride,
+                padding,
+                dilation,
+                depth_multiplier,
+                -input_zero_point,
+                output_zero_point,
+                quantized_multiplier_tensor,
+                quantized_shift_tensor,
+                output_qmin,
+                output_qmax,
+            )
+            return exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default, new_args
+        else:
+            # Use regular convolution operator
+            new_args = (
+                x,
+                weight_nhwc,
+                bias,
+                stride,
+                padding,
+                dilation,
+                -input_zero_point,
+                output_zero_point,
+                quantized_multiplier_tensor,
+                quantized_shift_tensor,
+                output_qmin,
+                output_qmax,
+            )
+            return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         modified = False
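The pass's detection rule and `depth_multiplier` arithmetic above can be checked directly against eager `Conv2d` modules, including the grouped configurations the new tests below exercise. A standalone sketch (names are illustrative):

```python
import torch

def describe_depthwise(conv: torch.nn.Conv2d) -> tuple[bool, int]:
    """Mirror the pass's rule on an eager Conv2d: weight dim 1 == 1 and groups > 1.
    Returns (is_depthwise, depth_multiplier); multiplier is 1 when not depthwise."""
    out_ch, in_per_group = conv.weight.shape[0], conv.weight.shape[1]
    if not (in_per_group == 1 and conv.groups > 1):
        return False, 1
    assert out_ch % conv.groups == 0
    return True, out_ch // conv.groups  # output_channels = input_channels * multiplier

# groups == in_channels, out == in -> multiplier 1
assert describe_depthwise(torch.nn.Conv2d(4, 4, 3, groups=4)) == (True, 1)
# out == 4 * in -> multiplier 4 (cf. the depthwise_conv2d_multiplier_4 test case)
assert describe_depthwise(torch.nn.Conv2d(2, 8, 3, groups=2)) == (True, 4)
# ordinary conv: groups == 1, so the rule does not fire
assert describe_depthwise(torch.nn.Conv2d(3, 8, 3)) == (False, 1)
```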
diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py
index 5630abbdab3..b43505f3b16 100644
--- a/backends/cortex_m/test/ops/test_conv.py
+++ b/backends/cortex_m/test/ops/test_conv.py
@@ -112,6 +112,50 @@ def forward(self, x):
         return x
 
 
+class CortexMDepthwiseConv2D(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_depthwise_conv2d_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(*args, **kwargs, bias=False)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class CortexMDepthwiseConv2DBias(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_depthwise_conv2d_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(*args, **kwargs, bias=True)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode
 test_cases = {
     "conv2d": McuTestCase(
@@ -178,6 +222,73 @@ def forward(self, x):
             ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
         ),
     ),
+    # Depthwise convolution tests (groups == in_channels)
+    "depthwise_conv2d": McuTestCase(
+        model=CortexMDepthwiseConv2D(4, 4, 3, groups=4),
+        example_inputs=(
+            ramp_tensor(1, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_multiplier": McuTestCase(
+        model=CortexMDepthwiseConv2D(3, 6, 3, groups=3),
+        example_inputs=(
+            ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_stride": McuTestCase(
+        model=CortexMDepthwiseConv2D(4, 4, 3, stride=2, groups=4),
+        example_inputs=(
+            ramp_tensor(-50, 50, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_padding": McuTestCase(
+        model=CortexMDepthwiseConv2D(2, 2, 5, padding=2, groups=2),
+        example_inputs=(
+            ramp_tensor(0, 1, (1, 2, 5, 5)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_bias": McuTestCase(
+        model=CortexMDepthwiseConv2DBias(3, 3, 3, padding=1, groups=3),
+        example_inputs=(
+            ramp_tensor(-10, 10, (1, 3, 6, 6)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_stride_padding_bias": McuTestCase(
+        model=CortexMDepthwiseConv2DBias(4, 4, 3, stride=2, padding=1, groups=4),
+        example_inputs=(
+            ramp_tensor(0, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_1x1": McuTestCase(
+        model=CortexMDepthwiseConv2D(4, 8, 1, groups=4),
+        example_inputs=(
+            ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_multiplier_4": McuTestCase(
+        model=CortexMDepthwiseConv2D(2, 8, 3, groups=2),
+        example_inputs=(
+            ramp_tensor(0, 10, (1, 2, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_asymmetric_kernel": McuTestCase(
+        model=CortexMDepthwiseConv2D(4, 4, (1, 3), groups=4),
+        example_inputs=(
+            ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_asymmetric_stride": McuTestCase(
+        model=CortexMDepthwiseConv2D(3, 3, 3, stride=(2, 1), padding=(1, 0), groups=3),
+        example_inputs=(
+            ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "depthwise_conv2d_5x5": McuTestCase(
+        model=CortexMDepthwiseConv2D(4, 4, 5, padding=2, groups=4),
+        example_inputs=(
+            ramp_tensor(0, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
 }
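Finally, a standalone sanity check that the OIHW -> IHWO storage chosen by the pass and the IHWO -> OIHW permute in the reference implementation are exact inverses, so the int32 accumulators are unchanged by the layout round trip. This mirrors the integer `F.conv2d` accumulation path that `quantized_depthwise_conv2d_impl` itself uses:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randint(-128, 128, (1, 4, 8, 8), dtype=torch.int8)
w_oihw = torch.randint(-128, 128, (8, 1, 3, 3), dtype=torch.int8)  # depth_multiplier = 2

# Pass: OIHW [C_OUT, 1, H, W] -> IHWO [1, H, W, C_OUT]
w_ihwo = w_oihw.permute(1, 2, 3, 0).contiguous()
# Reference impl: IHWO -> OIHW before the grouped conv
w_back = w_ihwo.permute(3, 0, 1, 2).contiguous()
assert torch.equal(w_back, w_oihw)

# Identical int32 accumulators either way (groups == input channels).
acc = F.conv2d(x.to(torch.int32), w_back.to(torch.int32), groups=4)
ref = F.conv2d(x.to(torch.int32), w_oihw.to(torch.int32), groups=4)
assert torch.equal(acc, ref)
```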