Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions backends/cadence/generic/operators/quantized_op_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
#pragma once

#include <executorch/backends/cadence/generic/operators/cadence_type_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

// Generate kernels that perform elementwise arithmetic on two quantized
// tensors. The tensors are either the same size, or the second tensor is a
// scalar.
#define DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \
template <typename T> \
void BINARY_FUNC_NAME( \
Expand All @@ -26,17 +24,19 @@
float out_scale, \
int32_t out_zero_point, \
::executorch::aten::Tensor& out) { \
const T* __restrict__ X_data = X.const_data_ptr<T>(); \
const T* __restrict__ Y_data = Y.const_data_ptr<T>(); \
T* __restrict__ out_data = out.mutable_data_ptr<T>(); \
float inv_out_scale = 1.0f / out_scale; \
for (size_t i = 0, e = X.numel(); i < e; ++i) { \
float x = ::impl::generic::kernels::dequantize<T>( \
X_data[i], X_scale, X_zero_point); \
float y = ::impl::generic::kernels::dequantize<T>( \
Y_data[i], Y_scale, Y_zero_point); \
float z = x OP y; \
out_data[i] = ::impl::generic::kernels::quantize<T>( \
z, inv_out_scale, out_zero_point); \
} \
::torch::executor::apply_binary_elementwise_fn<T, T, T>( \
[X_scale, X_zero_point, Y_scale, Y_zero_point, inv_out_scale, \
out_zero_point](const T x_val, const T y_val) { \
float x = ::impl::generic::kernels::dequantize<T>( \
x_val, X_scale, X_zero_point); \
float y = ::impl::generic::kernels::dequantize<T>( \
y_val, Y_scale, Y_zero_point); \
float z = x OP y; \
return ::impl::generic::kernels::quantize<T>( \
z, inv_out_scale, out_zero_point); \
}, \
X, \
Y, \
out); \
}
1 change: 1 addition & 0 deletions backends/cadence/generic/operators/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def define_common_targets():
exported_headers = ["quantized_op_macros.h"],
exported_deps = [
":cadence_type_util",
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/runtime/kernel:kernel_includes",
]
)
Expand Down
Loading