diff --git a/build.sh b/build.sh
index f343a0b..ce01194 100755
--- a/build.sh
+++ b/build.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 set -e
 
-# rm -rf build
+rm -rf build
 
 mkdir -p build
 cd build
diff --git a/musa_ext/kernels/musa_softplus_op.cc b/musa_ext/kernels/musa_softplus_op.cc
new file mode 100644
index 0000000..f5976c4
--- /dev/null
+++ b/musa_ext/kernels/musa_softplus_op.cc
@@ -0,0 +1,129 @@
+#include <mudnn.h>  // TODO(review): original include target was stripped by tooling; confirm header name.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "utils_op.h"
+
+namespace tensorflow {
+namespace musa {
+
+template <typename T>
+class MusaSoftplusOp : public MusaOpKernel {
+ public:
+  explicit MusaSoftplusOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) {}
+
+  // Computes softplus(x) = log(1 + exp(x)) in the numerically stable form
+  //   softplus(x) = relu(x) + log(1 + exp(-|x|))
+  // using a chain of MUSA DNN unary/binary ops and intermediate temps.
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& input = ctx->input(0);
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
+
+    if (input.NumElements() == 0) return;
+
+    auto& handle = GetHandleByCtx(ctx);
+
+    Tensor t_abs_tf, t_neg_abs_tf, t_exp_tf, t_exp_add1_tf, t_log_tf, t_relu_tf;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_abs_tf));
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_neg_abs_tf));
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_exp_tf));
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_exp_add1_tf));
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_log_tf));
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_relu_tf));
+
+    ::musa::dnn::Tensor x = CreateMTensor(input, format_);
+    ::musa::dnn::Tensor y = CreateMTensor(*output, format_);
+    ::musa::dnn::Tensor t_abs = CreateMTensor(t_abs_tf, format_);
+    ::musa::dnn::Tensor t_nabs = CreateMTensor(t_neg_abs_tf, format_);
+    ::musa::dnn::Tensor t_exp = CreateMTensor(t_exp_tf, format_);
+    ::musa::dnn::Tensor t_e1 = CreateMTensor(t_exp_add1_tf, format_);
+    ::musa::dnn::Tensor t_log = CreateMTensor(t_log_tf, format_);
+    ::musa::dnn::Tensor t_relu = CreateMTensor(t_relu_tf, format_);
+
+    auto check_status = [&](::musa::dnn::Status status, const char* msg) {
+      OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS,
+                  errors::Internal(msg, " Status: ", static_cast<int>(status)));
+    };
+
+    // 1) t_abs = abs(x)
+    {
+      ::musa::dnn::Unary op;
+      check_status(op.SetMode(::musa::dnn::Unary::Mode::ABS),
+                   "MUSA Softplus ABS SetMode failed.");
+      auto status = op.Run(handle, t_abs, x);
+      check_status(status, "MUSA Softplus ABS execution failed.");
+    }
+
+    // 2) t_neg_abs = -t_abs (via unary MUL with alpha = -1)
+    {
+      ::musa::dnn::Unary op;
+      check_status(op.SetMode(::musa::dnn::Unary::Mode::MUL),
+                   "MUSA Softplus MUL(SetMode) failed.");
+      check_status(op.SetAlpha(-1.0),
+                   "MUSA Softplus MUL(SetAlpha=-1) failed.");
+      auto status = op.Run(handle, t_nabs, t_abs);
+      check_status(status, "MUSA Softplus MUL(-1) execution failed.");
+    }
+
+    // 3) t_exp = exp(t_neg_abs)
+    {
+      ::musa::dnn::Unary op;
+      check_status(op.SetMode(::musa::dnn::Unary::Mode::EXP),
+                   "MUSA Softplus EXP SetMode failed.");
+      auto status = op.Run(handle, t_exp, t_nabs);
+      check_status(status, "MUSA Softplus EXP execution failed.");
+    }
+
+    // 4) t_exp_add1 = t_exp + 1 (via unary ADD with alpha = 1)
+    {
+      ::musa::dnn::Unary op;
+      check_status(op.SetMode(::musa::dnn::Unary::Mode::ADD),
+                   "MUSA Softplus ADD(SetMode) failed.");
+      check_status(op.SetAlpha(1.0),
+                   "MUSA Softplus ADD(SetAlpha=1) failed.");
+      auto status = op.Run(handle, t_e1, t_exp);
+      check_status(status, "MUSA Softplus ADD(+1) execution failed.");
+    }
+
+    // 5) t_log = log(t_exp_add1)
+    {
+      ::musa::dnn::Unary op;
+      check_status(op.SetMode(::musa::dnn::Unary::Mode::LOG),
+                   "MUSA Softplus LOG SetMode failed.");
+      auto status = op.Run(handle, t_log, t_e1);
+      check_status(status, "MUSA Softplus LOG execution failed.");
+    }
+
+    // 6) t_relu = relu(x) == max(x, 0)
+    {
+      ::musa::dnn::Unary op;
+      check_status(op.SetMode(::musa::dnn::Unary::Mode::RELU),
+                   "MUSA Softplus RELU SetMode failed.");
+      auto status = op.Run(handle, t_relu, x);
+      check_status(status, "MUSA Softplus RELU execution failed.");
+    }
+
+    // 7) y = t_relu + t_log
+    {
+      ::musa::dnn::Binary op;
+      check_status(op.SetMode(::musa::dnn::Binary::Mode::ADD),
+                   "MUSA Softplus Binary ADD SetMode failed.");
+      auto status = op.Run(handle, y, t_relu, t_log);
+      check_status(status, "MUSA Softplus final ADD execution failed.");
+    }
+  }
+};
+
+#define REGISTER_MUSA_SOFTPLUS(TYPE)                                 \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Softplus").Device("MUSA").TypeConstraint<TYPE>("T"),     \
+      MusaSoftplusOp<TYPE>)
+
+REGISTER_MUSA_SOFTPLUS(float);
+REGISTER_MUSA_SOFTPLUS(Eigen::half);
+REGISTER_MUSA_SOFTPLUS(bfloat16);
+
+#undef REGISTER_MUSA_SOFTPLUS
+
+}  // namespace musa
+}  // namespace tensorflow
diff --git a/test/softplus_op_test.py b/test/softplus_op_test.py
new file mode 100644
index 0000000..d3f9975
--- /dev/null
+++ b/test/softplus_op_test.py
@@ -0,0 +1,60 @@
+# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for MUSA Softplus operator."""
+
+import numpy as np
+import tensorflow as tf
+
+from musa_test_utils import MUSATestCase
+
+
+class SoftplusOpTest(MUSATestCase):
+
+  def _test_softplus(self, shape, dtype, rtol=1e-5, atol=1e-5):
+    np_dtype = dtype.as_numpy_dtype
+    if dtype == tf.bfloat16:
+      np_dtype = np.float32
+
+    if dtype == tf.float16:
+      low, high = -5.0, 5.0
+    elif dtype == tf.bfloat16:
+      low, high = -3.0, 3.0
+    else:
+      low, high = -10.0, 10.0
+
+    x_np = np.random.uniform(low, high, size=shape).astype(np_dtype)
+    x = tf.constant(x_np, dtype=dtype)
+
+    self._compare_cpu_musa_results(
+        tf.math.softplus,
+        [x],
+        dtype,
+        rtol=rtol,
+        atol=atol
+    )
+
+  def testSoftplusFloat32(self):
+    self._test_softplus([10, 10], tf.float32, rtol=1e-4, atol=1e-4)
+
+  def testSoftplusFloat16(self):
+    self._test_softplus([2, 3, 4], tf.float16, rtol=1e-2, atol=1e-2)
+
+  def testSoftplusBFloat16(self):
+    self._test_softplus([10, 10], tf.bfloat16, rtol=1e-1, atol=1e-1)
+
+
+if __name__ == "__main__":
+  tf.test.main()