From 19387c1ac689762a60458ff78941bc3863e4f051 Mon Sep 17 00:00:00 2001 From: tianhang Date: Thu, 26 Feb 2026 20:01:19 +0800 Subject: [PATCH 1/2] softplus unfinished --- musa_ext/kernels/musa_softplus_kernel.mu | 99 ++++++++++ musa_ext/kernels/musa_softplus_op.cc | 121 ++++++++++++ test/softplus_op_test.py | 223 +++++++++++++++++++++++ 3 files changed, 443 insertions(+) create mode 100644 musa_ext/kernels/musa_softplus_kernel.mu create mode 100644 musa_ext/kernels/musa_softplus_op.cc create mode 100644 test/softplus_op_test.py diff --git a/musa_ext/kernels/musa_softplus_kernel.mu b/musa_ext/kernels/musa_softplus_kernel.mu new file mode 100644 index 0000000..05bb047 --- /dev/null +++ b/musa_ext/kernels/musa_softplus_kernel.mu @@ -0,0 +1,99 @@ +// MUSA Softplus Custom Kernel +// Performs element-wise softplus: softplus(x) = log(1 + exp(x)) +// Numerically stable form: max(x, 0) + log1p(exp(-abs(x))) +// +// Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +// Licensed under the Apache License, Version 2.0. + +#include +#include +#include +#include + +extern "C" { + +// ---------------------------------------------------------------------------- +// Numerically stable softplus helpers +// ---------------------------------------------------------------------------- + +__device__ __forceinline__ float SoftplusStable(float x) { + float ax = fabsf(x); + float mx = x > 0.0f ? x : 0.0f; + return mx + log1pf(expf(-ax)); +} + +__device__ __forceinline__ double SoftplusStable(double x) { + double ax = fabs(x); + double mx = x > 0.0 ? x : 0.0; + return mx + log1p(exp(-ax)); +} + +// ---------------------------------------------------------------------------- +// Float kernel +// ---------------------------------------------------------------------------- + +__global__ void SoftplusKernelFloat(const float* input, float* output, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + output[idx] = SoftplusStable(input[idx]); + } +} + +// ---------------------------------------------------------------------------- +// Double kernel +// ---------------------------------------------------------------------------- + +__global__ void SoftplusKernelDouble(const double* input, double* output, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + output[idx] = SoftplusStable(input[idx]); + } +} + +// ---------------------------------------------------------------------------- +// Half (float16) kernel - compute in float for stability/precision +// ---------------------------------------------------------------------------- + +__global__ void SoftplusKernelHalf(const half* input, half* output, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + float x = __half2float(input[idx]); + float y = SoftplusStable(x); + output[idx] = __float2half(y); + } +} + +// ---------------------------------------------------------------------------- +// BFloat16 kernel - compute in float for stability/precision +// ---------------------------------------------------------------------------- + +__global__ void SoftplusKernelBFloat16(const __mt_bfloat16* input, + __mt_bfloat16* output, + int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + float x = __bfloat162float(input[idx]); + float y = SoftplusStable(x); + output[idx] = __float2bfloat16(y); + } +} + +// ---------------------------------------------------------------------------- +// Launcher functions - called from C++ code +// ---------------------------------------------------------------------------- + +#define DEFINE_SOFTPLUS_LAUNCHER(name, kernel, T) \ + void name(const T* input, T* output, int size, musaStream_t stream) { \ + const int threads_per_block = 256; \ + const int blocks = (size + threads_per_block - 1) / threads_per_block; \ + kernel<<>>(input, output, size); \ + } + +DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelFloat, SoftplusKernelFloat, float) +DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelDouble, SoftplusKernelDouble, double) +DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelHalf, SoftplusKernelHalf, half) +DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelBFloat16, SoftplusKernelBFloat16, __mt_bfloat16) + +#undef DEFINE_SOFTPLUS_LAUNCHER + +} // extern "C" \ No newline at end of file diff --git a/musa_ext/kernels/musa_softplus_op.cc b/musa_ext/kernels/musa_softplus_op.cc new file mode 100644 index 0000000..1eb5f01 --- /dev/null +++ b/musa_ext/kernels/musa_softplus_op.cc @@ -0,0 +1,121 @@ +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "utils_op.h" + +// ============================================================================ +// MUSA Softplus custom kernel launcher declarations from +// musa_softplus_kernel.mu +// ============================================================================ + +extern "C" { +void LaunchSoftplusKernelFloat(const float* input, float* output, int size, + musaStream_t stream); +void LaunchSoftplusKernelDouble(const double* input, double* output, int size, + musaStream_t stream); +void LaunchSoftplusKernelHalf(const void* input, void* output, int size, + musaStream_t stream); +void LaunchSoftplusKernelBFloat16(const void* input, void* output, int size, + musaStream_t stream); +} + +namespace tensorflow { +namespace musa { + +// ============================================================================ +// Common implementation for Softplus Compute +// ============================================================================ + +template +void SoftplusCompute(OpKernelContext* ctx, + void (*launcher)(const T*, T*, int, musaStream_t)) { + OP_REQUIRES(ctx, ctx->num_inputs() == 1, + errors::InvalidArgument("Softplus expects 1 input, got ", + ctx->num_inputs())); + + const Tensor& input = ctx->input(0); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); + + const int64 num_elements = input.NumElements(); + if (num_elements == 0) return; + + musaStream_t stream = + reinterpret_cast(GetHandleByCtx(ctx).GetStream()); + + const void* input_ptr = input.tensor_data().data(); + void* output_ptr = const_cast(output->tensor_data().data()); + + launcher(reinterpret_cast(input_ptr), + reinterpret_cast(output_ptr), static_cast(num_elements), + stream); +} + +// ============================================================================ +// Softplus operator class +// ============================================================================ + +template +class MusaSoftplusOp : public MusaOpKernel { + public: + explicit MusaSoftplusOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SoftplusCompute(ctx, GetLauncher()); + } + + private: + static void (*GetLauncher())(const T*, T*, int, musaStream_t); +}; + +// ============================================================================ +// Launcher function getters - specialized for each type +// ============================================================================ + +#define DEFINE_SOFTPLUS_LAUNCHER_GETTER(T, launcher, input_cast, output_cast) \ + template <> \ + void (*MusaSoftplusOp::GetLauncher())(const T*, T*, int, musaStream_t) { \ + return [](const T* input, T* output, int size, musaStream_t stream) { \ + launcher(input_cast(input), output_cast(output), size, stream); \ + }; \ + } + +// Float / double +DEFINE_SOFTPLUS_LAUNCHER_GETTER(float, LaunchSoftplusKernelFloat, + reinterpret_cast, + reinterpret_cast) + +DEFINE_SOFTPLUS_LAUNCHER_GETTER(double, LaunchSoftplusKernelDouble, + reinterpret_cast, + reinterpret_cast) + +// Half / BFloat16 use void* bridge +DEFINE_SOFTPLUS_LAUNCHER_GETTER(Eigen::half, LaunchSoftplusKernelHalf, + reinterpret_cast, + reinterpret_cast) + +DEFINE_SOFTPLUS_LAUNCHER_GETTER(bfloat16, LaunchSoftplusKernelBFloat16, + reinterpret_cast, + reinterpret_cast) + +#undef DEFINE_SOFTPLUS_LAUNCHER_GETTER + +// ============================================================================ +// Kernel registration (TF2 only) +// Name must match TensorFlow official op: "Softplus" +// ============================================================================ + +#define REGISTER_MUSA_SOFTPLUS(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softplus").Device("MUSA").TypeConstraint("T"), \ + MusaSoftplusOp); + +REGISTER_MUSA_SOFTPLUS(float); +REGISTER_MUSA_SOFTPLUS(double); +// REGISTER_MUSA_SOFTPLUS(Eigen::half); +// REGISTER_MUSA_SOFTPLUS(bfloat16); + +#undef REGISTER_MUSA_SOFTPLUS + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/test/softplus_op_test.py b/test/softplus_op_test.py new file mode 100644 index 0000000..02b8d60 --- /dev/null +++ b/test/softplus_op_test.py @@ -0,0 +1,223 @@ +# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for MUSA Softplus operator.""" + +import numpy as np +import tensorflow as tf + +from musa_test_utils import MUSATestCase + + +dtype_list = [tf.float32] +# dtype_list = [tf.float32, tf.float16, tf.bfloat16, tf.float64] +class SoftplusOpTest(MUSATestCase): + """Tests for MUSA Softplus operator.""" + + + def _test_softplus(self, shape, dtype, rtol=1e-5, atol=1e-8, + value_range=(-10.0, 10.0)): + """Test softplus operation with given shape and dtype.""" + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + + # Generate input tensor + input_np = np.random.uniform(value_range[0], value_range[1], size=shape).astype(np_dtype) + input_tf = tf.constant(input_np, dtype=dtype) + + def softplus_wrapper(x): + return tf.nn.softplus(x) + + self._compare_cpu_musa_results( + softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) + + def testSoftplusBasic(self): + """Test Softplus with basic shapes.""" + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + self._test_softplus([10], dtype, rtol=rtol, atol=atol) + self._test_softplus([128], dtype, rtol=rtol, atol=atol) + self._test_softplus([32, 64], dtype, rtol=rtol, atol=atol) + self._test_softplus([8, 16, 32], dtype, rtol=rtol, atol=atol) + + def testSoftplusLargeTensor(self): + """Test Softplus with large tensors.""" + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + self._test_softplus([2048, 2048], dtype, rtol=rtol, atol=atol) + + def testSoftplusEmptyTensor(self): + """Test Softplus with empty tensors.""" + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + # Empty tensor with shape [0] + self._test_softplus([0], dtype, rtol=rtol, atol=atol) + # Empty tensor with shape [0, 5] + self._test_softplus([0, 5], dtype, rtol=rtol, atol=atol) + + def testSoftplusZeroValues(self): + """Test Softplus with all-zero values.""" + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + input_np = np.zeros((256,), dtype=np_dtype) + input_tf = tf.constant(input_np, dtype=dtype) + + def softplus_wrapper(x): + return tf.nn.softplus(x) + + self._compare_cpu_musa_results( + softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) + + def testSoftplusMixedPositiveNegative(self): + """Test Softplus with mixed positive and negative values.""" + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + input_np = np.random.uniform(-50, 50, size=(512,)).astype(np_dtype) + input_tf = tf.constant(input_np, dtype=dtype) + + def softplus_wrapper(x): + return tf.nn.softplus(x) + + self._compare_cpu_musa_results( + softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) + + def testSoftplusExtremeValues(self): + """Test Softplus with extreme values for numerical stability.""" + for dtype in dtype_list: + rtol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + + # Use a conservative range for fp16/bf16 to reduce overflow-related mismatch + if dtype == tf.float16: + vals = np.array([-20.0, -10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0, 15.0], + dtype=np_dtype) + elif dtype == tf.bfloat16: + vals = np.array([-50.0, -20.0, -10.0, -1.0, 0.0, 1.0, 10.0, 20.0, 50.0], + dtype=np_dtype) + elif dtype == tf.float32: + vals = np.array([-100.0, -50.0, -20.0, -1.0, 0.0, 1.0, 20.0, 50.0, 100.0], + dtype=np_dtype) + else: # tf.float64 + vals = np.array([-500.0, -100.0, -20.0, -1.0, 0.0, 1.0, 20.0, 100.0, 500.0], + dtype=np_dtype) + + input_tf = tf.constant(vals, dtype=dtype) + + def softplus_wrapper(x): + return tf.nn.softplus(x) + + self._compare_cpu_musa_results( + softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) + + def testSoftplusSmallValues(self): + """Test Softplus with very small values around zero.""" + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-6 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + + if dtype == tf.float16: + eps = np.finfo(np.float16).eps + elif dtype == tf.float64: + eps = np.finfo(np.float64).eps + else: + eps = np.finfo(np.float32).eps + + vals = np.array([-10 * eps, -eps, 0.0, eps, 10 * eps], dtype=np_dtype) + input_tf = tf.constant(vals, dtype=dtype) + + def softplus_wrapper(x): + return tf.nn.softplus(x) + + self._compare_cpu_musa_results( + softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) + + def testSoftplusScalar(self): + """Test Softplus with scalar input.""" + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + input_np = np.array(1.2345, dtype=np_dtype) + input_tf = tf.constant(input_np, dtype=dtype) + + def softplus_wrapper(x): + return tf.nn.softplus(x) + + self._compare_cpu_musa_results( + softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) + + def testSoftplusDifferentShapes(self): + """Test Softplus on tensors with various dimensions.""" + test_shapes = [ + [1], + [7], + [1, 1], + [3, 5], + [2, 3, 4], + [2, 2, 2, 2], + ] + for dtype in dtype_list: + rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 + atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + + for shape in test_shapes: + self._test_softplus(shape, dtype, rtol=rtol, atol=atol) + + def testSoftplusMonotonicitySanity(self): + """Sanity test: softplus should be monotonic increasing.""" + # This checks both functional correctness and consistency on MUSA. + for dtype in dtype_list: + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + x_np = np.linspace(-10.0, 10.0, num=1000).astype(np_dtype) + x_tf = tf.constant(x_np, dtype=dtype) + + with tf.device('/device:MUSA:0'): + y = tf.nn.softplus(x_tf) + + y_np = tf.cast(y, tf.float32).numpy() if dtype in [tf.float16, tf.bfloat16] else y.numpy() + diffs = np.diff(y_np) + + # Allow tiny numerical noise for low precision + tol = 1e-4 if dtype in [tf.float16, tf.bfloat16] else 1e-8 + self.assertTrue(np.all(diffs >= -tol), + msg="Softplus output is not monotonic increasing.") + + def testSoftplusInvalidType(self): + """Test Softplus with invalid (non-floating) dtype should fail.""" + for dtype in [tf.int32, tf.int64]: + x = tf.constant([1, 2, 3], dtype=dtype) + with self.assertRaises((tf.errors.InvalidArgumentError, TypeError)): + with tf.device('/device:MUSA:0'): + _ = tf.nn.softplus(x) + + +if __name__ == "__main__": + tf.test.main() \ No newline at end of file From 940943bdf7e8300b8ffc90d269a92e8276cfc8a3 Mon Sep 17 00:00:00 2001 From: tianhang Date: Fri, 27 Feb 2026 17:41:07 +0800 Subject: [PATCH 2/2] feat: add new op softplus --- build.sh | 2 +- musa_ext/kernels/musa_softplus_kernel.mu | 99 ----------- musa_ext/kernels/musa_softplus_op.cc | 204 +++++++++++---------- test/softplus_op_test.py | 216 +++-------------------- 4 files changed, 133 insertions(+), 388 deletions(-) delete mode 100644 musa_ext/kernels/musa_softplus_kernel.mu diff --git a/build.sh b/build.sh index f343a0b..ce01194 100755 --- a/build.sh +++ b/build.sh @@ -1,7 +1,7 @@ #!/bin/bash set -e -# rm -rf build +rm -rf build mkdir -p build cd build diff --git a/musa_ext/kernels/musa_softplus_kernel.mu b/musa_ext/kernels/musa_softplus_kernel.mu deleted file mode 100644 index 05bb047..0000000 --- a/musa_ext/kernels/musa_softplus_kernel.mu +++ /dev/null @@ -1,99 +0,0 @@ -// MUSA Softplus Custom Kernel -// Performs element-wise softplus: softplus(x) = log(1 + exp(x)) -// Numerically stable form: max(x, 0) + log1p(exp(-abs(x))) -// -// Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. -// Licensed under the Apache License, Version 2.0. - -#include -#include -#include -#include - -extern "C" { - -// ---------------------------------------------------------------------------- -// Numerically stable softplus helpers -// ---------------------------------------------------------------------------- - -__device__ __forceinline__ float SoftplusStable(float x) { - float ax = fabsf(x); - float mx = x > 0.0f ? x : 0.0f; - return mx + log1pf(expf(-ax)); -} - -__device__ __forceinline__ double SoftplusStable(double x) { - double ax = fabs(x); - double mx = x > 0.0 ? x : 0.0; - return mx + log1p(exp(-ax)); -} - -// ---------------------------------------------------------------------------- -// Float kernel -// ---------------------------------------------------------------------------- - -__global__ void SoftplusKernelFloat(const float* input, float* output, int size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - output[idx] = SoftplusStable(input[idx]); - } -} - -// ---------------------------------------------------------------------------- -// Double kernel -// ---------------------------------------------------------------------------- - -__global__ void SoftplusKernelDouble(const double* input, double* output, int size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - output[idx] = SoftplusStable(input[idx]); - } -} - -// ---------------------------------------------------------------------------- -// Half (float16) kernel - compute in float for stability/precision -// ---------------------------------------------------------------------------- - -__global__ void SoftplusKernelHalf(const half* input, half* output, int size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - float x = __half2float(input[idx]); - float y = SoftplusStable(x); - output[idx] = __float2half(y); - } -} - -// ---------------------------------------------------------------------------- -// BFloat16 kernel - compute in float for stability/precision -// ---------------------------------------------------------------------------- - -__global__ void SoftplusKernelBFloat16(const __mt_bfloat16* input, - __mt_bfloat16* output, - int size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < size) { - float x = __bfloat162float(input[idx]); - float y = SoftplusStable(x); - output[idx] = __float2bfloat16(y); - } -} - -// ---------------------------------------------------------------------------- -// Launcher functions - called from C++ code -// ---------------------------------------------------------------------------- - -#define DEFINE_SOFTPLUS_LAUNCHER(name, kernel, T) \ - void name(const T* input, T* output, int size, musaStream_t stream) { \ - const int threads_per_block = 256; \ - const int blocks = (size + threads_per_block - 1) / threads_per_block; \ - kernel<<>>(input, output, size); \ - } - -DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelFloat, SoftplusKernelFloat, float) -DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelDouble, SoftplusKernelDouble, double) -DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelHalf, SoftplusKernelHalf, half) -DEFINE_SOFTPLUS_LAUNCHER(LaunchSoftplusKernelBFloat16, SoftplusKernelBFloat16, __mt_bfloat16) - -#undef DEFINE_SOFTPLUS_LAUNCHER - -} // extern "C" \ No newline at end of file diff --git a/musa_ext/kernels/musa_softplus_op.cc b/musa_ext/kernels/musa_softplus_op.cc index 1eb5f01..f5976c4 100644 --- a/musa_ext/kernels/musa_softplus_op.cc +++ b/musa_ext/kernels/musa_softplus_op.cc @@ -1,119 +1,127 @@ -#include "tensorflow/core/framework/bfloat16.h" +#include + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" #include "utils_op.h" -// ============================================================================ -// MUSA Softplus custom kernel launcher declarations from -// musa_softplus_kernel.mu -// ============================================================================ - -extern "C" { -void LaunchSoftplusKernelFloat(const float* input, float* output, int size, - musaStream_t stream); -void LaunchSoftplusKernelDouble(const double* input, double* output, int size, - musaStream_t stream); -void LaunchSoftplusKernelHalf(const void* input, void* output, int size, - musaStream_t stream); -void LaunchSoftplusKernelBFloat16(const void* input, void* output, int size, - musaStream_t stream); -} - namespace tensorflow { namespace musa { -// ============================================================================ -// Common implementation for Softplus Compute -// ============================================================================ - -template -void SoftplusCompute(OpKernelContext* ctx, - void (*launcher)(const T*, T*, int, musaStream_t)) { - OP_REQUIRES(ctx, ctx->num_inputs() == 1, - errors::InvalidArgument("Softplus expects 1 input, got ", - ctx->num_inputs())); - - const Tensor& input = ctx->input(0); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); - - const int64 num_elements = input.NumElements(); - if (num_elements == 0) return; - - musaStream_t stream = - reinterpret_cast(GetHandleByCtx(ctx).GetStream()); - - const void* input_ptr = input.tensor_data().data(); - void* output_ptr = const_cast(output->tensor_data().data()); - - launcher(reinterpret_cast(input_ptr), - reinterpret_cast(output_ptr), static_cast(num_elements), - stream); -} - -// ============================================================================ -// Softplus operator class -// ============================================================================ - template class MusaSoftplusOp : public MusaOpKernel { public: explicit MusaSoftplusOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - SoftplusCompute(ctx, GetLauncher()); + const Tensor& input = ctx->input(0); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); + + if (input.NumElements() == 0) return; + + auto& handle = GetHandleByCtx(ctx); + + Tensor t_abs_tf, t_neg_abs_tf, t_exp_tf, t_exp_add1_tf, t_log_tf, t_relu_tf; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_abs_tf)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_neg_abs_tf)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_exp_tf)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_exp_add1_tf)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_log_tf)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(input.dtype(), input.shape(), &t_relu_tf)); + + ::musa::dnn::Tensor x = CreateMTensor(input, format_); + ::musa::dnn::Tensor y = CreateMTensor(*output, format_); + ::musa::dnn::Tensor t_abs = CreateMTensor(t_abs_tf, format_); + ::musa::dnn::Tensor t_nabs = CreateMTensor(t_neg_abs_tf, format_); + ::musa::dnn::Tensor t_exp = CreateMTensor(t_exp_tf, format_); + ::musa::dnn::Tensor t_e1 = CreateMTensor(t_exp_add1_tf, format_); + ::musa::dnn::Tensor t_log = CreateMTensor(t_log_tf, format_); + ::musa::dnn::Tensor t_relu = CreateMTensor(t_relu_tf, format_); + + auto check_status = [&](::musa::dnn::Status status, const char* msg) { + OP_REQUIRES(ctx, status == ::musa::dnn::Status::SUCCESS, + errors::Internal(msg, " Status: ", static_cast(status))); + }; + + // 1) t_abs = abs(x) + { + ::musa::dnn::Unary op; + check_status(op.SetMode(::musa::dnn::Unary::Mode::ABS), + "MUSA Softplus ABS SetMode failed."); + auto status = op.Run(handle, t_abs, x); + check_status(status, "MUSA Softplus ABS execution failed."); + } + + // 2) t_neg_abs = -t_abs (via unary MUL with alpha = -1) + { + ::musa::dnn::Unary op; + check_status(op.SetMode(::musa::dnn::Unary::Mode::MUL), + "MUSA Softplus MUL(SetMode) failed."); + check_status(op.SetAlpha(-1.0), + "MUSA Softplus MUL(SetAlpha=-1) failed."); + auto status = op.Run(handle, t_nabs, t_abs); + check_status(status, "MUSA Softplus MUL(-1) execution failed."); + } + + // 3) t_exp = exp(t_neg_abs) + { + ::musa::dnn::Unary op; + check_status(op.SetMode(::musa::dnn::Unary::Mode::EXP), + "MUSA Softplus EXP SetMode failed."); + auto status = op.Run(handle, t_exp, t_nabs); + check_status(status, "MUSA Softplus EXP execution failed."); + } + + // 4) t_exp_add1 = t_exp + 1 (via unary ADD with alpha = 1) + { + ::musa::dnn::Unary op; + check_status(op.SetMode(::musa::dnn::Unary::Mode::ADD), + "MUSA Softplus ADD(SetMode) failed."); + check_status(op.SetAlpha(1.0), + "MUSA Softplus ADD(SetAlpha=1) failed."); + auto status = op.Run(handle, t_e1, t_exp); + check_status(status, "MUSA Softplus ADD(+1) execution failed."); + } + + // 5) t_log = log(t_exp_add1) + { + ::musa::dnn::Unary op; + check_status(op.SetMode(::musa::dnn::Unary::Mode::LOG), + "MUSA Softplus LOG SetMode failed."); + auto status = op.Run(handle, t_log, t_e1); + check_status(status, "MUSA Softplus LOG execution failed."); + } + + // 6) t_relu = relu(x) == max(x, 0) + { + ::musa::dnn::Unary op; + check_status(op.SetMode(::musa::dnn::Unary::Mode::RELU), + "MUSA Softplus RELU SetMode failed."); + auto status = op.Run(handle, t_relu, x); + check_status(status, "MUSA Softplus RELU execution failed."); + } + + // 7) y = t_relu + t_log + { + ::musa::dnn::Binary op; + check_status(op.SetMode(::musa::dnn::Binary::Mode::ADD), + "MUSA Softplus Binary ADD SetMode failed."); + auto status = op.Run(handle, y, t_relu, t_log); + check_status(status, "MUSA Softplus final ADD execution failed."); + } } - - private: - static void (*GetLauncher())(const T*, T*, int, musaStream_t); }; -// ============================================================================ -// Launcher function getters - specialized for each type -// ============================================================================ - -#define DEFINE_SOFTPLUS_LAUNCHER_GETTER(T, launcher, input_cast, output_cast) \ - template <> \ - void (*MusaSoftplusOp::GetLauncher())(const T*, T*, int, musaStream_t) { \ - return [](const T* input, T* output, int size, musaStream_t stream) { \ - launcher(input_cast(input), output_cast(output), size, stream); \ - }; \ - } - -// Float / double -DEFINE_SOFTPLUS_LAUNCHER_GETTER(float, LaunchSoftplusKernelFloat, - reinterpret_cast, - reinterpret_cast) - -DEFINE_SOFTPLUS_LAUNCHER_GETTER(double, LaunchSoftplusKernelDouble, - reinterpret_cast, - reinterpret_cast) - -// Half / BFloat16 use void* bridge -DEFINE_SOFTPLUS_LAUNCHER_GETTER(Eigen::half, LaunchSoftplusKernelHalf, - reinterpret_cast, - reinterpret_cast) - -DEFINE_SOFTPLUS_LAUNCHER_GETTER(bfloat16, LaunchSoftplusKernelBFloat16, - reinterpret_cast, - reinterpret_cast) - -#undef DEFINE_SOFTPLUS_LAUNCHER_GETTER - -// ============================================================================ -// Kernel registration (TF2 only) -// Name must match TensorFlow official op: "Softplus" -// ============================================================================ - -#define REGISTER_MUSA_SOFTPLUS(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Softplus").Device("MUSA").TypeConstraint("T"), \ - MusaSoftplusOp); +#define REGISTER_MUSA_SOFTPLUS(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Softplus").Device("MUSA").TypeConstraint("T"), \ + MusaSoftplusOp) REGISTER_MUSA_SOFTPLUS(float); -REGISTER_MUSA_SOFTPLUS(double); -// REGISTER_MUSA_SOFTPLUS(Eigen::half); -// REGISTER_MUSA_SOFTPLUS(bfloat16); +REGISTER_MUSA_SOFTPLUS(Eigen::half); +REGISTER_MUSA_SOFTPLUS(bfloat16); #undef REGISTER_MUSA_SOFTPLUS diff --git a/test/softplus_op_test.py b/test/softplus_op_test.py index 02b8d60..d3f9975 100644 --- a/test/softplus_op_test.py +++ b/test/softplus_op_test.py @@ -21,202 +21,38 @@ from musa_test_utils import MUSATestCase -dtype_list = [tf.float32] -# dtype_list = [tf.float32, tf.float16, tf.bfloat16, tf.float64] class SoftplusOpTest(MUSATestCase): - """Tests for MUSA Softplus operator.""" - - def _test_softplus(self, shape, dtype, rtol=1e-5, atol=1e-8, - value_range=(-10.0, 10.0)): - """Test softplus operation with given shape and dtype.""" - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + def _test_softplus(self, shape, dtype, rtol=1e-5, atol=1e-5): + np_dtype = dtype.as_numpy_dtype + if dtype == tf.bfloat16: + np_dtype = np.float32 - # Generate input tensor - input_np = np.random.uniform(value_range[0], value_range[1], size=shape).astype(np_dtype) - input_tf = tf.constant(input_np, dtype=dtype) + if dtype == tf.float16: + low, high = -5.0, 5.0 + elif dtype == tf.bfloat16: + low, high = -3.0, 3.0 + else: + low, high = -10.0, 10.0 - def softplus_wrapper(x): - return tf.nn.softplus(x) + x_np = np.random.uniform(low, high, size=shape).astype(np_dtype) + x = tf.constant(x_np, dtype=dtype) self._compare_cpu_musa_results( - softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) - - def testSoftplusBasic(self): - """Test Softplus with basic shapes.""" - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - self._test_softplus([10], dtype, rtol=rtol, atol=atol) - self._test_softplus([128], dtype, rtol=rtol, atol=atol) - self._test_softplus([32, 64], dtype, rtol=rtol, atol=atol) - self._test_softplus([8, 16, 32], dtype, rtol=rtol, atol=atol) - - def testSoftplusLargeTensor(self): - """Test Softplus with large tensors.""" - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - self._test_softplus([2048, 2048], dtype, rtol=rtol, atol=atol) - - def testSoftplusEmptyTensor(self): - """Test Softplus with empty tensors.""" - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - # Empty tensor with shape [0] - self._test_softplus([0], dtype, rtol=rtol, atol=atol) - # Empty tensor with shape [0, 5] - self._test_softplus([0, 5], dtype, rtol=rtol, atol=atol) - - def testSoftplusZeroValues(self): - """Test Softplus with all-zero values.""" - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype - input_np = np.zeros((256,), dtype=np_dtype) - input_tf = tf.constant(input_np, dtype=dtype) - - def softplus_wrapper(x): - return tf.nn.softplus(x) - - self._compare_cpu_musa_results( - softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) - - def testSoftplusMixedPositiveNegative(self): - """Test Softplus with mixed positive and negative values.""" - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype - input_np = np.random.uniform(-50, 50, size=(512,)).astype(np_dtype) - input_tf = tf.constant(input_np, dtype=dtype) - - def softplus_wrapper(x): - return tf.nn.softplus(x) - - self._compare_cpu_musa_results( - softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) - - def testSoftplusExtremeValues(self): - """Test Softplus with extreme values for numerical stability.""" - for dtype in dtype_list: - rtol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 2e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype - - # Use a conservative range for fp16/bf16 to reduce overflow-related mismatch - if dtype == tf.float16: - vals = np.array([-20.0, -10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0, 15.0], - dtype=np_dtype) - elif dtype == tf.bfloat16: - vals = np.array([-50.0, -20.0, -10.0, -1.0, 0.0, 1.0, 10.0, 20.0, 50.0], - dtype=np_dtype) - elif dtype == tf.float32: - vals = np.array([-100.0, -50.0, -20.0, -1.0, 0.0, 1.0, 20.0, 50.0, 100.0], - dtype=np_dtype) - else: # tf.float64 - vals = np.array([-500.0, -100.0, -20.0, -1.0, 0.0, 1.0, 20.0, 100.0, 500.0], - dtype=np_dtype) - - input_tf = tf.constant(vals, dtype=dtype) - - def softplus_wrapper(x): - return tf.nn.softplus(x) - - self._compare_cpu_musa_results( - softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) - - def testSoftplusSmallValues(self): - """Test Softplus with very small values around zero.""" - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-6 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype - - if dtype == tf.float16: - eps = np.finfo(np.float16).eps - elif dtype == tf.float64: - eps = np.finfo(np.float64).eps - else: - eps = np.finfo(np.float32).eps - - vals = np.array([-10 * eps, -eps, 0.0, eps, 10 * eps], dtype=np_dtype) - input_tf = tf.constant(vals, dtype=dtype) - - def softplus_wrapper(x): - return tf.nn.softplus(x) - - self._compare_cpu_musa_results( - softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) - - def testSoftplusScalar(self): - """Test Softplus with scalar input.""" - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype - input_np = np.array(1.2345, dtype=np_dtype) - input_tf = tf.constant(input_np, dtype=dtype) - - def softplus_wrapper(x): - return tf.nn.softplus(x) - - self._compare_cpu_musa_results( - softplus_wrapper, [input_tf], dtype, rtol=rtol, atol=atol) - - def testSoftplusDifferentShapes(self): - """Test Softplus on tensors with various dimensions.""" - test_shapes = [ - [1], - [7], - [1, 1], - [3, 5], - [2, 3, 4], - [2, 2, 2, 2], - ] - for dtype in dtype_list: - rtol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-5 - atol = 1e-2 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - - for shape in test_shapes: - self._test_softplus(shape, dtype, rtol=rtol, atol=atol) - - def testSoftplusMonotonicitySanity(self): - """Sanity test: softplus should be monotonic increasing.""" - # This checks both functional correctness and consistency on MUSA. - for dtype in dtype_list: - np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype - x_np = np.linspace(-10.0, 10.0, num=1000).astype(np_dtype) - x_tf = tf.constant(x_np, dtype=dtype) - - with tf.device('/device:MUSA:0'): - y = tf.nn.softplus(x_tf) - - y_np = tf.cast(y, tf.float32).numpy() if dtype in [tf.float16, tf.bfloat16] else y.numpy() - diffs = np.diff(y_np) - - # Allow tiny numerical noise for low precision - tol = 1e-4 if dtype in [tf.float16, tf.bfloat16] else 1e-8 - self.assertTrue(np.all(diffs >= -tol), - msg="Softplus output is not monotonic increasing.") - - def testSoftplusInvalidType(self): - """Test Softplus with invalid (non-floating) dtype should fail.""" - for dtype in [tf.int32, tf.int64]: - x = tf.constant([1, 2, 3], dtype=dtype) - with self.assertRaises((tf.errors.InvalidArgumentError, TypeError)): - with tf.device('/device:MUSA:0'): - _ = tf.nn.softplus(x) + tf.math.softplus, + [x], + dtype, + rtol=rtol, + atol=atol + ) + def testSoftplusFloat32(self): + self._test_softplus([10, 10], tf.float32, rtol=1e-4, atol=1e-4) + + def testSoftplusFloat16(self): + self._test_softplus([2, 3, 4], tf.float16, rtol=1e-2, atol=1e-2) + + def testSoftplusBFloat16(self): + self._test_softplus([10, 10], tf.bfloat16, rtol=1e-1, atol=1e-1) if __name__ == "__main__":