From 3d9b8006bce4eb57db92df1ab9e58a31e389b25e Mon Sep 17 00:00:00 2001 From: dijopaul Date: Tue, 11 Nov 2025 01:02:01 -0800 Subject: [PATCH] Adding fixes for im2row and conv1d --- backends/cadence/aot/functions_hifi.yaml | 15 + backends/cadence/hifi/kernels/CMakeLists.txt | 1 + backends/cadence/hifi/kernels/kernels.h | 22 ++ .../cadence/hifi/operators/CMakeLists.txt | 6 + .../cadence/hifi/operators/im2row_out.cpp | 370 ++++++++++++++++++ ...ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp | 113 ++++-- ...ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp | 74 ++-- ...nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp | 82 +++- ...nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp | 12 +- .../hifi/operators/op_transpose_copy.cpp | 166 ++++++++ .../hifi/third-party/nnlib/xa_nn_im2row.c | 106 +++++ 11 files changed, 884 insertions(+), 83 deletions(-) create mode 100644 backends/cadence/hifi/operators/im2row_out.cpp create mode 100644 backends/cadence/hifi/operators/op_transpose_copy.cpp create mode 100644 backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 3bdbb33d59b..aa00c26485d 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -267,6 +267,11 @@ - arg_meta: null kernel_name: impl::HiFi::tanh_out +- op: transpose_copy.int_out + kernels: + - arg_meta: null + kernel_name: impl::HiFi::transpose_copy_int_out + - op: view_copy.out kernels: - arg_meta: null @@ -278,6 +283,16 @@ kernel_name: impl::HiFi::where_self_out # custom ops +- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::im2row_out + +- func: cadence::im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::im2row_per_tensor_out + - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) 
variants: function kernels: diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt b/backends/cadence/hifi/kernels/CMakeLists.txt index 936e28e2241..c366cecbe0c 100644 --- a/backends/cadence/hifi/kernels/CMakeLists.txt +++ b/backends/cadence/hifi/kernels/CMakeLists.txt @@ -18,6 +18,7 @@ add_library( ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_fmod_broadcast_f32.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_greater_lesser_equal_f32.c + ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_logicalxor_bool_bool.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index 08343e2528b..6a3dcd1d245 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -196,6 +196,28 @@ extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32( const unsigned char* __restrict__ p_condition, const WORD32* const p_condition_shape); +extern "C" WORD32 xa_nn_im2row_quantized( + const WORD8* __restrict__ data_im, + const WORD32 in_zero_point, + /* input parameters*/ + const WORD32 channels, + const WORD32 height, + const WORD32 width, + /* output parameters */ + const WORD32 out_height, + const WORD32 out_width, + /* convolution parameters */ + const WORD32 kernel_h, + const WORD32 kernel_w, + const WORD32 pad_h, + const WORD32 pad_w, + const WORD32 stride_h, + const WORD32 stride_w, + const WORD32 dilation_h, + const WORD32 dilation_w, + WORD8* __restrict__ data_col, + WORD32 channels_last); + extern "C" WORD32 xa_nn_reduce_mean_4D_f32_f32( FLOAT32* __restrict__ p_out, const WORD32* const p_out_shape, diff --git a/backends/cadence/hifi/operators/CMakeLists.txt b/backends/cadence/hifi/operators/CMakeLists.txt index 26555da9760..90e98031af8 100644 --- a/backends/cadence/hifi/operators/CMakeLists.txt +++ b/backends/cadence/hifi/operators/CMakeLists.txt @@ -16,6 +16,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake) # ATen compliant ops that are needed to run this model. 
set(_aten_ops__srcs + "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/im2row_out.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_atan2.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_bitwise_and.cpp" @@ -52,6 +53,7 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp" + "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_transpose_copy.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_view_copy.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" @@ -96,6 +98,10 @@ add_library( "op_quantize_per_tensor.cpp" "op_quantized_relu_out.cpp" "op_dequantize_per_tensor.cpp" + "op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out" + "op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out" + "op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out" + "op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out" "op_quantized_conv2d_nchw_out.cpp" "op_quantized_conv2d_nhwc_out.cpp" "op_quantized_fully_connected_out" diff --git a/backends/cadence/hifi/operators/im2row_out.cpp b/backends/cadence/hifi/operators/im2row_out.cpp new file mode 100644 index 00000000000..2e9f1814719 --- /dev/null +++ b/backends/cadence/hifi/operators/im2row_out.cpp @@ -0,0 +1,370 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +namespace impl { +namespace HiFi { +namespace native { + +template <typename T> +__attribute__((always_inline)) void im2row_( + const T* __restrict__ data_im, + const int32_t in_zero_point, + /* input parameters*/ + const int32_t channels, + const int32_t height, + const int32_t width, + /* output parameters */ + const int32_t out_height, + const int32_t out_width, + /* convolution parameters */ + const int32_t kernel_h, + const int32_t kernel_w, + const int32_t pad_h, + const int32_t pad_w, + const int32_t stride_h, + const int32_t stride_w, + const int32_t dilation_h, + const int32_t dilation_w, + T* __restrict__ data_col, + bool channels_last) { + // Consider convolving the input image of dimensions channels * height * width + // (or height * width * channels for NHWC layout) with a filter of dimensions + // channels * kernel_h * kernel_w. Assume that this convolution will produce + // an output of dimensions out_height x out_width. For each point in the output, + // im2row takes the data from the input that is used in the computation of + // that output point, and flattens it into a vector of size channels_col = + // channels * kernel_h * kernel_w. The output of im2row will therefore be a 2D + // array of size (out_height * out_width) x channels_col. + const int32_t channels_col = channels * kernel_h * kernel_w; + + // If the layout is NHWC, we can copy 'channels' worth of contiguous data + // points when performing im2row.
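+ // For example, channels = 3 with a 2x2 kernel gives channels_col = 12; a 4x4 + // input with stride 1, no padding, and unit dilation produces a 3x3 output, so + // data_col holds 9 rows of 12 values in either layout.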
+ if (channels_last) { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + // Each point in the output domain is the result of applying a filter of + // size kernel_h x kernel_w x channels on the input. But since channels + // is contiguous, we will not explicitly have a loop for it. + for (int _kh = 0; _kh < kernel_h; ++_kh) { + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + + // h_im and w_im are the actual height and width coordinates of the + // input tensor from where we need to copy 'channels' points. + const T* __restrict__ slice_im = + data_im + (h_im * width + w_im) * channels; + T* __restrict__ slice_col = data_col + i_col * channels_col + + (_kh * kernel_w + _kw) * channels; + // If the coordinates were within the input domain, we copy + // 'channels' contiguous values. Otherwise we will fill the output + // with 0's. + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + std::memcpy(slice_col, slice_im, channels * sizeof(T)); + } else { + std::fill_n(slice_col, channels, T(in_zero_point)); + } + } + } + } + } + } else { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + + // Each point in the output domain is the result of applying a filter + // of size channels x kernel_h x kernel_w on the input + for (int _c = 0; _c < channels; ++_c) { + for (int _kh = 0; _kh < kernel_h; ++_kh) { + for (int _kw = 0; _kw < kernel_w; ++_kw) { + // c_col is the linearized access in the channels_col vector. + int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + // If the current data access is within the input tensor, copy the + // value + data_col[i_col * channels_col + c_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? data_im[(_c * height + h_im) * width + w_im] + : static_cast<T>(in_zero_point); + } + } + } + } + } + } +} + +void im2row_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + const Tensor& in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width.
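+ // (Standard convolution arithmetic; e.g. in_h = 5, kernel_h = 3, stride_h = 1, + // pad_h = 0, dilation_h = 1 gives out_h = 3.)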
+ int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + // Check if the input is per-tensor quantized or per-channel quantized. The + // zero point for each batch could differ for per-channel quantized input. + bool per_tensor_quantized = in_zero_point.numel() == 1; + + bool optimized = false; + if (input.scalar_type() == ScalarType::Char || + input.scalar_type() == ScalarType::Byte) + optimized = true; + + if (optimized) { + const int8_t* __restrict__ in_data = + (WORD8* __restrict__)input.const_data_ptr(); + int8_t* __restrict__ out_data = out.mutable_data_ptr(); + const int32_t* __restrict__ zero_point = + in_zero_point.const_data_ptr(); + int32_t in_plane = in_c * in_h * in_w; + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; + for (size_t n = 0; n < batch_size; ++n) { + xa_nn_im2row_quantized( + &in_data[n * in_plane], + per_tensor_quantized ? zero_point[0] : zero_point[n], + in_c, + in_h, + in_w, + out_h, + out_w, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + &out_data[n * out_plane], + channel_last ? 1 : 0); + } + } else { +#define typed_im2row(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const int32_t* __restrict__ zero_point = \ + in_zero_point.const_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + per_tensor_quantized ? zero_point[0] : zero_point[n], \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row(Float, float); + typed_im2row(Byte, uint8_t); + typed_im2row(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row + } +} + +void im2row_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + int64_t in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width. 
+ int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + + bool optimized = false; + if (input.scalar_type() == ScalarType::Char || + input.scalar_type() == ScalarType::Byte) + optimized = true; + + if (optimized) { + const int8_t* __restrict__ in_data = + (WORD8* __restrict__)input.const_data_ptr(); + int8_t* __restrict__ out_data = out.mutable_data_ptr(); + int32_t in_plane = in_c * in_h * in_w; + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; + + for (size_t n = 0; n < batch_size; ++n) { + xa_nn_im2row_quantized( + &in_data[n * in_plane], + (int32_t)in_zero_point, + in_c, + in_h, + in_w, + out_h, + out_w, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + &out_data[n * out_plane], + channel_last ? 1 : 0); + } + } else { +#define typed_im2row_per_tensor(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + in_zero_point, \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row_per_tensor(Float, float); + typed_im2row_per_tensor(Byte, uint8_t); + typed_im2row_per_tensor(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row.per_tensor not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row_per_tensor + } +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp index b5ab0cdbaa2..f543f4633cf 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -35,7 +35,7 @@ void xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( float output_scale, int32_t output_zero_point, Tensor& out) { - constexpr int kNnlibMaxDim = 3; + constexpr int kNnlibMaxDim = 5; WORD8* __restrict__ p_out = (WORD8* __restrict__)out.mutable_data_ptr(); @@ -49,19 +49,29 @@ void xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( WORD32 batches = input.size(0); WORD32 input_channels = input.size(1); WORD32 input_width = input.size(2); + WORD32 input_height = 1; + WORD32 kernel_height = 1; WORD32 out_channels = weight.size(0); WORD32 kernel_channels = weight.size(1); WORD32 kernel_width = weight.size(2); WORD32 out_width = out.size(2); + WORD32 out_height = 1; WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_height = 1; + WORD32 dilation_width = 1; WORD32 input_zero_bias = -in_zero_point; - WORD32 
out_multiplier32 = bias_scale * (1. / output_scale) * 2147483648; - WORD32 out_shift32 = 0; WORD32 kernel_zero_bias = -weight_zero_point; WORD32 out_zero_bias = output_zero_point; + + WORD32 input_precision = 8; + WORD32 kernel_precision = 8; + WORD32 out_data_format = 1; + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( ctx, ((batches * input_channels * input_width) + 8) * sizeof(WORD8)); WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( @@ -71,16 +81,20 @@ void xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = batches; - p_inp_shape[1] = input_channels; - p_inp_shape[2] = input_width; + p_inp_shape[0] = 1; + p_inp_shape[1] = 1; + p_inp_shape[2] = batches; + p_inp_shape[3] = input_channels; + p_inp_shape[4] = input_width; WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = batches; - p_out_shape[1] = input_width; - p_out_shape[2] = input_channels; + p_out_shape[0] = 1; + p_out_shape[1] = 1; + p_out_shape[2] = batches; + p_out_shape[3] = input_width; + p_out_shape[4] = input_channels; - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 1}; + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 1, 2, 4, 3}; xa_nn_transpose_8_8( pin, @@ -92,14 +106,18 @@ void xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( kNnlibMaxDim); WORD32 p_inp_shape1[kNnlibMaxDim]; - p_inp_shape1[0] = out_channels; - p_inp_shape1[1] = kernel_channels; - p_inp_shape1[2] = kernel_width; + p_inp_shape1[0] = 1; + p_inp_shape1[1] = 1; + p_inp_shape1[2] = out_channels; + p_inp_shape1[3] = kernel_channels; + p_inp_shape1[4] = kernel_width; WORD32 p_out_shape1[kNnlibMaxDim]; - p_out_shape1[0] = out_channels; - p_out_shape1[1] = kernel_width; - p_out_shape1[2] = kernel_channels; + p_out_shape1[0] = 1; + p_out_shape1[1] = 1; + p_out_shape1[2] = out_channels; + p_out_shape1[3] = kernel_width; + p_out_shape1[4] = kernel_channels; xa_nn_transpose_8_8( pkernel, @@ -110,34 +128,71 @@ void xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( kNnlibMaxDim, kNnlibMaxDim); - WORD32 scratch_size = - xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + WORD32 p_out_multiplier32[out_channels]; + WORD32 p_out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + p_out_multiplier32[i] = bias_scale * out_scale * 2147483648; + p_out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + x_stride, + x_padding, + y_stride, + y_padding, + out_height, + out_width, + out_channels, + input_precision, + kernel_precision, + out_data_format); + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + pVOID p_scratch = nullptr; WORD32* ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = pin + _n * input_channels * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_width; + WORD8* in_batch = pin + _n * input_channels * 1 * input_width; + WORD8* out_batch = p_out + _n * out_channels * 1 * out_width; - xa_nn_conv1d_std_asym8xasym8( - (UWORD8*)out_batch, - (UWORD8*)in_batch, - (UWORD8*)pkernel, + xa_nn_conv2d_per_chan_sym8sxasym8s( + out_batch, + in_batch, + pkernel, p_bias, - 1, + input_height, input_width, input_channels, + kernel_height, kernel_width, + kernel_channels, + dilation_height, + dilation_width, out_channels, x_stride, + y_stride, x_padding, + y_padding, + out_height, out_width, input_zero_bias, - kernel_zero_bias, - out_multiplier32, - out_shift32, + p_out_multiplier32, + p_out_shift32, out_zero_bias, out_data_format, p_scratch); diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp index 60e700f563b..4ad36a3b5fa 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -35,7 +35,7 @@ void xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( float output_scale, int32_t output_zero_point, Tensor& out) { - constexpr int kNnlibMaxDim = 3; + constexpr int kNnlibMaxDim = 5; UWORD8* __restrict__ p_out = (UWORD8* __restrict__)out.mutable_data_ptr(); @@ -49,10 +49,13 @@ void xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( WORD32 batches = input.size(0); WORD32 input_channels = input.size(1); WORD32 input_width = input.size(2); + WORD32 input_height = 1; + WORD32 kernel_height = 1; WORD32 out_channels = weight.size(0); WORD32 kernel_channels = weight.size(1); WORD32 kernel_width = weight.size(2); WORD32 out_width = out.size(2); + WORD32 out_height = 1; WORD32 x_stride = stride[1]; WORD32 x_padding = padding[1]; WORD32 input_zero_bias = -in_zero_point; @@ -62,25 +65,37 @@ void xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( WORD32 out_zero_bias = output_zero_point; WORD32 out_data_format = 1; - UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( - ctx, ((batches * input_channels * input_width) + 8) * sizeof(UWORD8)); - UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory( + + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, ((batches * input_channels * input_width) + 8) * sizeof(WORD8)); + WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( ctx, - ((out_channels * kernel_channels * kernel_width) + 8) * sizeof(UWORD8)); - UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8); - UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8); + ((out_channels * kernel_channels * kernel_width) + 8) * sizeof(WORD8)); + WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8); + WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = batches; - p_inp_shape[1] = input_channels; - p_inp_shape[2] = input_width; + p_inp_shape[0] = 1; + p_inp_shape[1] = 1; + p_inp_shape[2] = batches; + p_inp_shape[3] = input_channels; + p_inp_shape[4] = input_width; WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = batches; - p_out_shape[1] = input_width; - p_out_shape[2] = input_channels; + p_out_shape[0] = 1; + p_out_shape[1] = 1; + p_out_shape[2] = batches; + p_out_shape[3] = input_width; + p_out_shape[4] = input_channels; - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 1}; + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 1, 2, 4, 3}; xa_nn_transpose_8_8( (WORD8*)pin, @@ -92,14 +107,18 @@ void xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( kNnlibMaxDim); WORD32 p_inp_shape1[kNnlibMaxDim]; - p_inp_shape1[0] = out_channels; - p_inp_shape1[1] = kernel_channels; - p_inp_shape1[2] = kernel_width; + p_inp_shape1[0] = 1; + p_inp_shape1[1] = 1; + p_inp_shape1[2] = out_channels; + p_inp_shape1[3] = kernel_channels; + p_inp_shape1[4] = kernel_width; WORD32 p_out_shape1[kNnlibMaxDim]; - p_out_shape1[0] = out_channels; - p_out_shape1[1] = kernel_width; - p_out_shape1[2] = kernel_channels; + p_out_shape1[0] = 1; + p_out_shape1[1] = 1; + p_out_shape1[2] = out_channels; + p_out_shape1[3] = kernel_width; + p_out_shape1[4] = kernel_channels; xa_nn_transpose_8_8( (WORD8*)pkernel, @@ -110,24 +129,17 @@ void xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( kNnlibMaxDim, kNnlibMaxDim); - WORD32 scratch_size = - xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - WORD32* ptr_scratch = - (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - for (int _n = 0; _n < batches; _n++) { - UWORD8* in_batch = pin + _n * input_channels * input_width; - UWORD8* out_batch = p_out + _n * out_channels * out_width; + UWORD8* in_batch = (UWORD8*)(pin + _n * input_channels * input_width); + UWORD8* out_batch = (UWORD8*)(p_out + _n * out_channels * out_width); xa_nn_conv1d_std_asym8uxasym8u( out_batch, in_batch, - pkernel, + (UWORD8*)pkernel, p_bias, - 1, input_width, + input_height, input_channels, kernel_width, out_channels, diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp index c9a3d2b58de..3b1c7b9a900 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -47,46 +47,94 @@ void xa_opt_quantized_conv1d_nlc_asym8sxsym8s_asym8s( WORD32 batches = input.size(0); WORD32 input_channels = input.size(1); WORD32 input_width = input.size(2); + WORD32 input_height = 1; + WORD32 kernel_height = 1; WORD32 out_channels = weight.size(0); + WORD32 kernel_channels = weight.size(1); WORD32 kernel_width = weight.size(2); WORD32 out_width = out.size(2); + WORD32 out_height = 1; WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + WORD32 dilation_height = 1; + WORD32 dilation_width = 1; WORD32 input_zero_bias = -in_zero_point; - WORD32 out_multiplier32 = bias_scale * (1. / output_scale) * 2147483648; - WORD32 out_shift32 = 0; WORD32 kernel_zero_bias = -weight_zero_point; + WORD32 input_precision = 8; + WORD32 kernel_precision = 8; + WORD32 out_zero_bias = output_zero_point; - WORD32 out_data_format = 0; - WORD32 scratch_size = - xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + + WORD32 out_data_format = 1; + + WORD32 p_out_multiplier32[out_channels]; + WORD32 p_out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + p_out_multiplier32[i] = bias_scale * out_scale * 2147483648; + p_out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + kernel_channels, + dilation_height, + dilation_width, + y_stride, + y_padding, + x_stride, + x_padding, + out_height, + out_width, + out_channels, + input_precision, + kernel_precision, + out_data_format); + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + pVOID p_scratch = nullptr; WORD32* ptr_scratch = - (WORD32*)::impl::HiFi::kernels::allocate_temp_memory(ctx, scratch_size); - pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + + p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = p_inp + _n * input_channels * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_width; + WORD8* in_batch = p_inp + _n * input_channels * 1 * input_width; + WORD8* out_batch = p_out + _n * out_channels * 1 * out_width; - xa_nn_conv1d_std_asym8xasym8( - (UWORD8*)out_batch, - (UWORD8*)in_batch, - (UWORD8*)p_kernel, + xa_nn_conv2d_per_chan_sym8sxasym8s( + out_batch, + in_batch, + p_kernel, p_bias, - 1, + input_height, input_width, input_channels, + kernel_height, kernel_width, + kernel_channels, + dilation_height, + dilation_width, out_channels, x_stride, + y_stride, x_padding, + y_padding, + out_height, out_width, input_zero_bias, - kernel_zero_bias, - out_multiplier32, - out_shift32, + p_out_multiplier32, + p_out_shift32, out_zero_bias, out_data_format, p_scratch); diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp index 2d7a4cba509..5539410f46e 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -45,11 +45,11 @@ void xa_opt_quantized_conv1d_nlc_asym8uxsym8u_asym8u( (WORD32* __restrict__)bias.const_data_ptr(); WORD32 batches = input.size(0); - WORD32 input_channels = input.size(1); - WORD32 input_width = input.size(2); - WORD32 out_channels = weight.size(0); - WORD32 kernel_width = weight.size(2); - WORD32 out_width = out.size(2); + WORD32 input_channels = input.size(2); + WORD32 input_width = input.size(1); + WORD32 out_channels = weight.size(2); + WORD32 kernel_width = weight.size(1); + WORD32 out_width = out.size(1); WORD32 x_stride = stride[1]; WORD32 x_padding = padding[1]; WORD32 input_zero_bias = -in_zero_point; @@ -75,8 +75,8 @@ void xa_opt_quantized_conv1d_nlc_asym8uxsym8u_asym8u( in_batch, p_kernel, p_bias, - 1, input_width, + 1, input_channels, kernel_width, out_channels, diff --git a/backends/cadence/hifi/operators/op_transpose_copy.cpp b/backends/cadence/hifi/operators/op_transpose_copy.cpp new file mode 100644 index 00000000000..a21a7f6178c --- /dev/null +++ b/backends/cadence/hifi/operators/op_transpose_copy.cpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +using executorch::aten::ScalarType; +using executorch::aten::SizesType; +using executorch::aten::Tensor; +using executorch::runtime::Error; +using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::kTensorDimensionLimit; +using executorch::runtime::nonzero_dim; +using executorch::runtime::resize_tensor; +using executorch::runtime::tensors_have_same_dim_order; +using torch::executor::check_transpose_copy_args; +using torch::executor::get_transpose_out_target_size; +using torch::executor::transpose_tensors; + +namespace impl { +namespace HiFi { +namespace native { + +/** + * Swaps dimension 'dim0' of 'a' with 'dim1', and copying + * that mutation into `out` in a manner such that the data is densely packed + * and is_contiguous() would return true (stride dim[size-1] = 1). + * + * transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) + */ +Tensor& transpose_copy_int_out( + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim0, + int64_t dim1, + Tensor& out) { + (void)ctx; + + if (dim0 < 0) { + dim0 += nonzero_dim(in); + } + if (dim1 < 0) { + dim1 += nonzero_dim(in); + } + + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; + size_t expected_out_dim = 0; + get_transpose_out_target_size( + in, dim0, dim1, expected_out_size, &expected_out_dim); + + // Resize for dynamic shape + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok, + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + const auto in_type = out.scalar_type(); + constexpr int kNnlibMaxDim = 5; + + bool optimized = false; + + if (out.scalar_type() == ScalarType::Float || + out.scalar_type() == ScalarType::Char || + out.scalar_type() == ScalarType::Byte) + optimized = true; + + if (in.dim() > kNnlibMaxDim) + optimized = false; + + if (optimized) { + WORD32 num_inp_dims = in.dim(); + WORD32 num_out_dims = num_inp_dims; + + WORD32 p_inp_shape[kNnlibMaxDim]; + WORD32 p_out_shape[kNnlibMaxDim]; + WORD32 p_permute_vec[kNnlibMaxDim]; + + for (int i = 0; i < in.dim(); i++) { + p_inp_shape[i] = in.size(i); + } + for (int i = 0; i < out.dim(); i++) { + p_out_shape[i] = out.size(i); + } + + for (int i = 0; i < in.dim(); i++) { + p_permute_vec[i] = i; + } + + p_permute_vec[dim0] = dim1; + p_permute_vec[dim1] = dim0; + + if (in_type == ScalarType::Float) { + WORD32* p_inp = (WORD32*)in.const_data_ptr(); + WORD32* p_out = (WORD32*)out.mutable_data_ptr(); + + WORD32 ret_val = xa_nn_transpose_32_32( + p_out, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + num_out_dims, + num_inp_dims); + + ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out); + + } else if (in_type == ScalarType::Char) { + WORD8* p_inp = (WORD8*)in.const_data_ptr(); + WORD8* p_out = (WORD8*)out.mutable_data_ptr(); + + WORD32 val = xa_nn_transpose_8_8( + p_out, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + num_out_dims, + num_inp_dims); + + ET_KERNEL_CHECK(ctx, val == 0, Internal, out); + + } else if (in_type == ScalarType::Byte) { + WORD8* p_inp = (WORD8*)in.const_data_ptr(); + WORD8* p_out = (WORD8*)out.mutable_data_ptr(); + + WORD32 val = xa_nn_transpose_8_8( + p_out, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + num_out_dims, + num_inp_dims); + + ET_KERNEL_CHECK(ctx, val == 0, Internal, out); + } + + return out; + } + + ET_KERNEL_CHECK( + ctx, + check_transpose_copy_args(in, dim0, dim1, out), + InvalidArgument, + out); + + 
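+ // Fallback path: dtypes or ranks that the NNLib transpose kernels above do not + // cover are handled by the generic portable transpose below.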
ET_SWITCH_ALL_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] { + transpose_tensors(in, dim0, dim1, out); + }); + + return out; +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c new file mode 100644 index 00000000000..7008ee58f0a --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c @@ -0,0 +1,106 @@ +#include "xa_nn_common.h" +#include "xa_nnlib_common_fpu.h" +#include "xa_nnlib_err_chk.h" +#include "xa_type_def.h" +// #include "xa_nn_basic_state.h" +#include "xa_nnlib_kernels_api.h" + +WORD32 xa_nn_im2row_quantized( + const WORD8 *__restrict__ data_im, const WORD32 in_zero_point, + /* input parameters*/ + const WORD32 channels, const WORD32 height, const WORD32 width, + /* output parameters */ + const WORD32 out_height, const WORD32 out_width, + /* convolution parameters */ + const WORD32 kernel_h, const WORD32 kernel_w, const WORD32 pad_h, + const WORD32 pad_w, const WORD32 stride_h, const WORD32 stride_w, + const WORD32 dilation_h, const WORD32 dilation_w, + WORD8 *__restrict__ data_col, WORD32 channels_last) { + const WORD32 channels_col = channels * kernel_h * kernel_w; + + // If the layout is NHWC, we can copy 'channels' worth of contiguous data + // points when performing im2row. + if (channels_last) { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + // Each point in the output domain is the result of applying a filter of + // size kernel_h x kernel_w x channels on the input. But since channels + // is contiguous, we will not explicitly have a loop for it. + for (int _kh = 0; _kh < kernel_h; ++_kh) { + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + + // h_im and w_im are the actual height and width coordinates of the + // input tensor from where we need to copy 'channels' points. + const int8_t *__restrict__ slice_im = + data_im + (h_im * width + w_im) * channels; + int8_t *__restrict__ slice_col = data_col + i_col * channels_col + + (_kh * kernel_w + _kw) * channels; + // If the coordinates were within the input domain, we copy + // 'channels' contiguous values. Otherwise we will fill the output + // with 0's. 
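+ // The in-bounds copy below moves 8 bytes per iteration using unaligned 64-bit + // loads/stores (AE_LA32X2_IP / AE_SA32X2_IP); the remaining channels & 7 bytes + // are copied one at a time.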
+ if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + const ae_int32x2 *pae_inp = (const ae_int32x2 *)slice_im; + ae_int32x2 *pae_out = (ae_int32x2 *)slice_col; + ae_valign inp_a, out_a; + inp_a = AE_LA64_PP(pae_inp); + out_a = AE_ZALIGN64(); + + ae_int32x2 d0; + for (int ic = 0; ic < channels >> 3; ic++) { + AE_LA32X2_IP(d0, inp_a, pae_inp); + AE_SA32X2_IP(d0, out_a, pae_out); + } + AE_SA64POS_FP(out_a, pae_out); + + int remainder = channels & 7; + int8_t *ptmp_in = (int8_t *)pae_inp; + int8_t *ptmp_out = (int8_t *)pae_out; + for (int ic = 0; ic < remainder; ic++) { + *ptmp_out++ = *ptmp_in++; + } + } else { + for (int i = 0; i < channels; i++) { + slice_col[i] = (int8_t)(in_zero_point); + } + } + } + } + } + } else { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + + // Each point in the output domain is the result of applying a filter + // of size channels x kernel_h x kernel_w on the input + for (int _c = 0; _c < channels; ++_c) { + for (int _kh = 0; _kh < kernel_h; ++_kh) { + for (int _kw = 0; _kw < kernel_w; ++_kw) { + // c_col is the linearized access in the channels_col vector. + int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + // If the current data access is within the input tensor, copy the + // value + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + data_col[i_col * channels_col + c_col] = + data_im[(_c * height + h_im) * width + w_im]; + else + data_col[i_col * channels_col + c_col] = (int8_t)in_zero_point; + } + } + } + } + } + } + + return 0; +}