Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash
set -e

# rm -rf build
rm -rf build

mkdir -p build
cd build
Expand Down
129 changes: 129 additions & 0 deletions musa_ext/kernels/musa_softplus_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#include <mudnn.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "utils_op.h"

namespace tensorflow {
namespace musa {

template <typename T>
class MusaSoftplusOp : public MusaOpKernel {
 public:
  explicit MusaSoftplusOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) {}

  // Computes y = softplus(x) = log(1 + exp(x)) elementwise.
  //
  // Uses the numerically stable decomposition
  //     softplus(x) = max(x, 0) + log(1 + exp(-|x|))
  // so exp() never overflows for large positive inputs, expressed as a
  // pipeline of muDNN unary/binary primitives.
  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // Nothing to launch for an empty tensor.
    if (input.NumElements() == 0) return;

    auto& handle = GetHandleByCtx(ctx);

    // Scratch buffers, one per intermediate stage of the decomposition.
    Tensor t_abs_tf, t_neg_abs_tf, t_exp_tf, t_exp_add1_tf, t_log_tf, t_relu_tf;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_abs_tf));
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_neg_abs_tf));
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_exp_tf));
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_exp_add1_tf));
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_log_tf));
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_relu_tf));

    ::musa::dnn::Tensor x = CreateMTensor(input, format_);
    ::musa::dnn::Tensor y = CreateMTensor(*output, format_);
    ::musa::dnn::Tensor t_abs = CreateMTensor(t_abs_tf, format_);
    ::musa::dnn::Tensor t_nabs = CreateMTensor(t_neg_abs_tf, format_);
    ::musa::dnn::Tensor t_exp = CreateMTensor(t_exp_tf, format_);
    ::musa::dnn::Tensor t_e1 = CreateMTensor(t_exp_add1_tf, format_);
    ::musa::dnn::Tensor t_log = CreateMTensor(t_log_tf, format_);
    ::musa::dnn::Tensor t_relu = CreateMTensor(t_relu_tf, format_);

    // BUGFIX: the previous helper wrapped OP_REQUIRES inside a lambda, whose
    // `return` only exits the *lambda* — on a muDNN failure, Compute() kept
    // launching the remaining stages even though the op status was already an
    // error. The helper now reports success so each call site can abort the
    // pipeline immediately.
    auto ok = [&](::musa::dnn::Status status, const char* msg) -> bool {
      if (status != ::musa::dnn::Status::SUCCESS) {
        ctx->SetStatus(
            errors::Internal(msg, " Status: ", static_cast<int>(status)));
        return false;
      }
      return true;
    };

    // 1) t_abs = |x|
    {
      ::musa::dnn::Unary op;
      if (!ok(op.SetMode(::musa::dnn::Unary::Mode::ABS),
              "MUSA Softplus ABS SetMode failed.")) return;
      if (!ok(op.Run(handle, t_abs, x),
              "MUSA Softplus ABS execution failed.")) return;
    }

    // 2) t_neg_abs = -t_abs (unary MUL with alpha = -1)
    {
      ::musa::dnn::Unary op;
      if (!ok(op.SetMode(::musa::dnn::Unary::Mode::MUL),
              "MUSA Softplus MUL(SetMode) failed.")) return;
      if (!ok(op.SetAlpha(-1.0),
              "MUSA Softplus MUL(SetAlpha=-1) failed.")) return;
      if (!ok(op.Run(handle, t_nabs, t_abs),
              "MUSA Softplus MUL(-1) execution failed.")) return;
    }

    // 3) t_exp = exp(t_neg_abs); argument is <= 0 so exp() cannot overflow.
    {
      ::musa::dnn::Unary op;
      if (!ok(op.SetMode(::musa::dnn::Unary::Mode::EXP),
              "MUSA Softplus EXP SetMode failed.")) return;
      if (!ok(op.Run(handle, t_exp, t_nabs),
              "MUSA Softplus EXP execution failed.")) return;
    }

    // 4) t_exp_add1 = t_exp + 1 (unary ADD with alpha = 1)
    {
      ::musa::dnn::Unary op;
      if (!ok(op.SetMode(::musa::dnn::Unary::Mode::ADD),
              "MUSA Softplus ADD(SetMode) failed.")) return;
      if (!ok(op.SetAlpha(1.0),
              "MUSA Softplus ADD(SetAlpha=1) failed.")) return;
      if (!ok(op.Run(handle, t_e1, t_exp),
              "MUSA Softplus ADD(+1) execution failed.")) return;
    }

    // 5) t_log = log(t_exp_add1)
    {
      ::musa::dnn::Unary op;
      if (!ok(op.SetMode(::musa::dnn::Unary::Mode::LOG),
              "MUSA Softplus LOG SetMode failed.")) return;
      if (!ok(op.Run(handle, t_log, t_e1),
              "MUSA Softplus LOG execution failed.")) return;
    }

    // 6) t_relu = relu(x) == max(x, 0)
    {
      ::musa::dnn::Unary op;
      if (!ok(op.SetMode(::musa::dnn::Unary::Mode::RELU),
              "MUSA Softplus RELU SetMode failed.")) return;
      if (!ok(op.Run(handle, t_relu, x),
              "MUSA Softplus RELU execution failed.")) return;
    }

    // 7) y = t_relu + t_log — combine the two stable halves.
    {
      ::musa::dnn::Binary op;
      if (!ok(op.SetMode(::musa::dnn::Binary::Mode::ADD),
              "MUSA Softplus Binary ADD SetMode failed.")) return;
      if (!ok(op.Run(handle, y, t_relu, t_log),
              "MUSA Softplus final ADD execution failed.")) return;
    }
  }
};

// Registers MusaSoftplusOp as the implementation of the "Softplus" op on the
// MUSA device for element type TYPE.
#define REGISTER_MUSA_SOFTPLUS(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Softplus").Device("MUSA").TypeConstraint<TYPE>("T"), \
MusaSoftplusOp<TYPE>)

// Floating-point types supported by this kernel (matching the dtypes
// exercised by test/softplus_op_test.py).
REGISTER_MUSA_SOFTPLUS(float);
REGISTER_MUSA_SOFTPLUS(Eigen::half);
REGISTER_MUSA_SOFTPLUS(bfloat16);

// Keep the helper macro file-local.
#undef REGISTER_MUSA_SOFTPLUS

} // namespace musa
} // namespace tensorflow
59 changes: 59 additions & 0 deletions test/softplus_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for MUSA Softplus operator."""

import numpy as np
import tensorflow as tf

from musa_test_utils import MUSATestCase


class SoftplusOpTest(MUSATestCase):
  """Compares tf.math.softplus results between CPU and the MUSA device."""

  def _test_softplus(self, shape, dtype, rtol=1e-5, atol=1e-5):
    """Runs softplus on random inputs on CPU and MUSA and compares results.

    Args:
      shape: Shape of the random input tensor.
      dtype: TensorFlow dtype under test.
      rtol: Relative tolerance for the CPU/MUSA comparison.
      atol: Absolute tolerance for the CPU/MUSA comparison.
    """
    np_dtype = dtype.as_numpy_dtype
    if dtype == tf.bfloat16:
      # NumPy has no native bfloat16; sample in float32 and let tf.constant
      # downcast below.
      np_dtype = np.float32

    # Narrow the sampling range for low-precision dtypes so rounding error
    # stays within the comparison tolerances.
    if dtype == tf.float16:
      low, high = -5.0, 5.0
    elif dtype == tf.bfloat16:
      low, high = -3.0, 3.0
    else:
      low, high = -10.0, 10.0

    # BUGFIX: seed the generator so a failing run is reproducible; the
    # previous unseeded np.random.uniform produced different inputs per run.
    rng = np.random.default_rng(1234)
    x_np = rng.uniform(low, high, size=shape).astype(np_dtype)
    x = tf.constant(x_np, dtype=dtype)

    self._compare_cpu_musa_results(
        tf.math.softplus,
        [x],
        dtype,
        rtol=rtol,
        atol=atol
    )

  def testSoftplusFloat32(self):
    self._test_softplus([10, 10], tf.float32, rtol=1e-4, atol=1e-4)

  def testSoftplusFloat16(self):
    self._test_softplus([2, 3, 4], tf.float16, rtol=1e-2, atol=1e-2)

  def testSoftplusBFloat16(self):
    self._test_softplus([10, 10], tf.bfloat16, rtol=1e-1, atol=1e-1)


# Allows running this test file directly (e.g. `python softplus_op_test.py`).
if __name__ == "__main__":
  tf.test.main()