From a9aa29e62a40b6910539772bb12cbb1cf24b1e88 Mon Sep 17 00:00:00 2001 From: albert Date: Thu, 12 Feb 2026 13:55:27 +0800 Subject: [PATCH 1/4] feat: isNanOp --- musa_ext/kernels/musa_addn_kernel.mu | 119 +++++++++++++++++++++++++ musa_ext/kernels/musa_isnan_kernel.mu | 86 ++++++++++++++++++ musa_ext/kernels/musa_isnan_op.cc | 63 ++++++++++++++ test/isnan_op_test.py | 121 ++++++++++++++++++++++++++ 4 files changed, 389 insertions(+) create mode 100644 musa_ext/kernels/musa_addn_kernel.mu create mode 100644 musa_ext/kernels/musa_isnan_kernel.mu create mode 100644 musa_ext/kernels/musa_isnan_op.cc create mode 100644 test/isnan_op_test.py diff --git a/musa_ext/kernels/musa_addn_kernel.mu b/musa_ext/kernels/musa_addn_kernel.mu new file mode 100644 index 0000000..374519c --- /dev/null +++ b/musa_ext/kernels/musa_addn_kernel.mu @@ -0,0 +1,119 @@ +#include +#include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-pragmas" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/bfloat16.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#pragma GCC diagnostic pop + +namespace tensorflow { +namespace musa { + +// 设备函数:加载不同数据类型的值 +__device__ __forceinline__ float LoadFloat(const float* p) { return *p; } +__device__ __forceinline__ void StoreFloat(float* p, float v) { *p = v; } + +__device__ __forceinline__ float LoadFloat(const Eigen::half* p) { + const __half* h_ptr = reinterpret_cast(p); + return __half2float(*h_ptr); +} + +__device__ __forceinline__ void StoreFloat(Eigen::half* p, float v) { + __half h = __float2half(v); + *reinterpret_cast<__half*>(p) = h; +} + +__device__ __forceinline__ float LoadFloat(const bfloat16* p) { + float res = 0.0f; + uint16_t* b_ptr = (uint16_t*)p; + uint32_t* f_ptr = (uint32_t*)&res; + *f_ptr = ((uint32_t)(*b_ptr)) << 16; + return res; +} + +__device__ __forceinline__ void StoreFloat(bfloat16* p, float v) { + uint32_t* f_ptr = (uint32_t*)&v; + uint16_t b_val = (*f_ptr) >> 16; + *reinterpret_cast(p) = b_val; +} + +// 整数类型的支持 +__device__ __forceinline__ int32_t LoadInt32(const int32_t* p) { return *p; } +__device__ __forceinline__ void StoreInt32(int32_t* p, int32_t v) { *p = v; } + +// 使用 tensorflow::int64 而不是 int64_t +__device__ __forceinline__ tensorflow::int64 LoadInt64(const tensorflow::int64* p) { return *p; } +__device__ __forceinline__ void StoreInt64(tensorflow::int64* p, tensorflow::int64 v) { *p = v; } + +// 双精度支持 +__device__ __forceinline__ double LoadDouble(const double* p) { return *p; } +__device__ __forceinline__ void StoreDouble(double* p, double v) { *p = v; } + +// 主要的AddN kernel模板 - 使用 const T* const* 参数类型 +template +__global__ void AddNKernel(const T* const* inputs, T* output, int num_inputs, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + // 使用适当的数据类型进行累加 + T sum = inputs[0][idx]; + for (int i = 1; i < num_inputs; ++i) { + sum += inputs[i][idx]; + } + output[idx] = sum; + } +} + +// 特化版本:使用float中间计算(适用于半精度) +template <> +__global__ void AddNKernel(const Eigen::half* const* inputs, Eigen::half* output, int num_inputs, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float sum = LoadFloat(&inputs[0][idx]); + for (int i = 1; i < num_inputs; ++i) { + sum += LoadFloat(&inputs[i][idx]); + } + StoreFloat(&output[idx], sum); + } +} + +template <> +__global__ void AddNKernel(const bfloat16* const* inputs, bfloat16* output, int num_inputs, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float sum = LoadFloat(&inputs[0][idx]); + for (int i = 1; i < num_inputs; ++i) { + sum += LoadFloat(&inputs[i][idx]); + } + StoreFloat(&output[idx], sum); + } +} + +// 启动函数 - 使用 const T* const* 参数类型 +template +void LaunchAddN(const T* const* inputs, T* output, int num_inputs, int n, musaStream_t stream) { + if (n <= 0 || num_inputs <= 0) return; + + int threads = 256; + int blocks = (n + threads - 1) / threads; + + AddNKernel<<>>(inputs, output, num_inputs, n); + + // 检查kernel启动错误 + musaError_t err = musaGetLastError(); + if (err != musaSuccess) { + // 错误处理将在C++层处理 + } +} + +// 显式实例化 - 使用 tensorflow::int64 +// LaunchAddN 是模板函数 必须有显示实例化 编译器才会在编译阶段生成函数代码 +template void LaunchAddN(const float* const*, float*, int, int, musaStream_t); +template void LaunchAddN(const double* const*, double*, int, int, musaStream_t); +template void LaunchAddN(const Eigen::half* const*, Eigen::half*, int, int, musaStream_t); +template void LaunchAddN(const bfloat16* const*, bfloat16*, int, int, musaStream_t); +template void LaunchAddN(const int32_t* const*, int32_t*, int, int, musaStream_t); +template void LaunchAddN(const tensorflow::int64* const*, tensorflow::int64*, int, int, musaStream_t); + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_isnan_kernel.mu b/musa_ext/kernels/musa_isnan_kernel.mu new file mode 100644 index 0000000..8346f0b --- /dev/null +++ b/musa_ext/kernels/musa_isnan_kernel.mu @@ -0,0 +1,86 @@ +#include +#include +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-pragmas" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/bfloat16.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#pragma GCC diagnostic pop + +namespace tensorflow { +namespace musa { + +// --------- 工具:half / bfloat16 转 float --------- +__device__ __forceinline__ float LoadFloat(const float* p) { return *p; } + +__device__ __forceinline__ float LoadFloat(const Eigen::half* p) { + const __half* h_ptr = reinterpret_cast(p); + return __half2float(*h_ptr); +} + +__device__ __forceinline__ float LoadFloat(const bfloat16* p) { + float res = 0.0f; + const uint16_t* b_ptr = reinterpret_cast(p); + uint32_t* f_ptr = reinterpret_cast(&res); + *f_ptr = (static_cast(*b_ptr)) << 16; + return res; +} + +// --------- isnan 判定(float/double 直接用 isnan;half/bf16 转 float) --------- +__device__ __forceinline__ bool IsNanValue(float v) { return isnan(v); } +__device__ __forceinline__ bool IsNanValue(double v) { return isnan(v); } + +// --------- Kernel:通用模板(float/double)--------- +template +__global__ void IsNanKernel(const T* input, bool* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = IsNanValue(input[idx]); + } +} + +// --------- 特化:Eigen::half --------- +template <> +__global__ void IsNanKernel(const Eigen::half* input, bool* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float v = LoadFloat(&input[idx]); + output[idx] = IsNanValue(v); + } +} + +// --------- 特化:bfloat16 --------- +template <> +__global__ void IsNanKernel(const bfloat16* input, bool* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float v = LoadFloat(&input[idx]); + output[idx] = IsNanValue(v); + } +} + +// --------- Launch --------- +template +void LaunchIsNan(const T* input, bool* output, int n, musaStream_t stream) { + if (n <= 0) return; + + int threads = 256; + int blocks = (n + threads - 1) / threads; + + IsNanKernel<<>>(input, output, n); + + // kernel 启动错误检查(和你 AddN 风格一致:错误在上层处理也行) + musaError_t err = musaGetLastError(); + (void)err; +} + +// 显式实例化 +template void LaunchIsNan(const float*, bool*, int, musaStream_t); +template void LaunchIsNan(const double*, bool*, int, musaStream_t); +template void LaunchIsNan(const Eigen::half*, bool*, int, musaStream_t); +template void LaunchIsNan(const bfloat16*, bool*, int, musaStream_t); + +} // namespace musa +} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_isnan_op.cc b/musa_ext/kernels/musa_isnan_op.cc new file mode 100644 index 0000000..960d4d2 --- /dev/null +++ b/musa_ext/kernels/musa_isnan_op.cc @@ -0,0 +1,63 @@ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/bfloat16.h" +#include "utils_op.h" +#include "mu/device/musa_device.h" + +// 声明 kernel 启动函数 +namespace tensorflow { +namespace musa { +template +void LaunchIsNan(const T* input, bool* output, int n, musaStream_t stream); +} // namespace musa +} // namespace tensorflow + +namespace tensorflow { +namespace musa { + +template +class MusaIsNanOp : public MusaOpKernel { + public: + explicit MusaIsNanOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& x = ctx->input(0); + + Tensor* y = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); + + const int64_t n64 = y->NumElements(); + if (n64 == 0) return; + + // TF Tensor 的 NumElements() 是 int64,这里 kernel 用 int + OP_REQUIRES(ctx, n64 <= static_cast(std::numeric_limits::max()), + errors::InvalidArgument("IsNan: tensor is too large, num_elements=", n64)); + + const int n = static_cast(n64); + + const T* x_ptr = x.flat().data(); + bool* y_ptr = y->flat().data(); + + auto* device = GetDeviceByCtx(ctx); + auto stream = device->GetStream(); + + LaunchIsNan(x_ptr, y_ptr, n, stream); + } +}; + +// 只注册 MUSA 设备上的 kernel;Op 本体(REGISTER_OP)由 TF Core 提供 +#define REGISTER_MUSA_ISNAN(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("IsNan") \ + .Device(DEVICE_MTGPU) \ + .TypeConstraint("T"), \ + MusaIsNanOp); + +REGISTER_MUSA_ISNAN(float); +REGISTER_MUSA_ISNAN(double); +REGISTER_MUSA_ISNAN(Eigen::half); +REGISTER_MUSA_ISNAN(bfloat16); + +#undef REGISTER_MUSA_ISNAN + +} // namespace musa +} // namespace tensorflow diff --git a/test/isnan_op_test.py b/test/isnan_op_test.py new file mode 100644 index 0000000..48ef8be --- /dev/null +++ b/test/isnan_op_test.py @@ -0,0 +1,121 @@ +# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for MUSA IsNan operator. + +This test assumes: +- TensorFlow core has registered the Op 'IsNan' (math_ops.cc). +- The MUSA plugin registers a DEVICE_MTGPU kernel for 'IsNan'. +""" +import numpy as np +import tensorflow as tf + +from musa_test_utils import MUSATestCase + + +class IsNanOpTest(MUSATestCase): + """Tests for MUSA IsNan operator.""" + + def _make_input(self, shape, dtype, inject_nan=True, fill_value=None, include_inf=False): + """Create a numpy input array for a given TF dtype.""" + np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype + + if fill_value is not None: + x_np = np.full(shape, fill_value, dtype=np_dtype) + else: + x_np = np.random.uniform(-1.0, 1.0, size=shape).astype(np_dtype) + + if include_inf and x_np.size > 0: + x_np.flat[0] = np.inf + if x_np.size > 1: + x_np.flat[1] = -np.inf + + if inject_nan and x_np.size > 0: + # Put NaNs in deterministic positions + x_np.flat[0] = np.nan + x_np.flat[x_np.size // 2] = np.nan + x_np.flat[-1] = np.nan + + return x_np + + def _test_isnan(self, shape, dtype, inject_nan=True, fill_value=None, include_inf=False): + """Test IsNan operation with given shape and dtype.""" + x_np = self._make_input(shape, dtype, inject_nan=inject_nan, + fill_value=fill_value, include_inf=include_inf) + x_tf = tf.constant(x_np, dtype=dtype) + + # Numeric proxy so musa_test_utils._compare_cpu_musa_results can use assertAllClose. + # (IsNan's real output is bool; bool comparison is done below.) + def isnan_proxy(x): + return tf.cast(tf.math.is_nan(x), tf.float32) + + self._compare_cpu_musa_results(isnan_proxy, [x_tf], dtype, rtol=0.0, atol=0.0) + + # Also validate the true output dtype/shape and exact bool equality CPU vs MUSA. + with tf.device("/CPU:0"): + cpu_bool = tf.math.is_nan(x_tf) + with tf.device("/device:MUSA:0"): + musa_bool = tf.math.is_nan(x_tf) + + self.assertEqual(cpu_bool.dtype, tf.bool) + self.assertEqual(musa_bool.dtype, tf.bool) + self.assertAllEqual(cpu_bool.shape.as_list(), x_tf.shape.as_list()) + self.assertAllEqual(musa_bool.shape.as_list(), x_tf.shape.as_list()) + self.assertAllEqual(cpu_bool.numpy(), musa_bool.numpy()) + + def testIsNanSmall(self): + """Small tensor correctness.""" + for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]: + self._test_isnan([10], dtype) + + def testIsNanLarge(self): + """Larger tensor correctness.""" + for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]: + self._test_isnan([256, 256], dtype) + + def testIsNanEmptyTensor(self): + """Empty tensors should return empty bool tensors with same shape.""" + for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]: + self._test_isnan([0], dtype, inject_nan=False) + self._test_isnan([0, 5], dtype, inject_nan=False) + + def testIsNanNoNaNs(self): + """If there are no NaNs, all outputs should be False.""" + for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]: + self._test_isnan([1024], dtype, inject_nan=False, include_inf=False) + + def testIsNanAllNaNs(self): + """All NaNs should yield all True.""" + for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]: + # For float16/bf16, np.nan is representable; TF will carry NaN. + self._test_isnan([128], dtype, inject_nan=False, fill_value=np.nan) + + def testIsNanWithInfs(self): + """Infs are not NaNs; only NaNs should be True.""" + for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]: + self._test_isnan([64], dtype, inject_nan=True, include_inf=True) + + def testIsNanInvalidDType(self): + """IsNan should reject non-floating types per TF op definition.""" + for dtype in [tf.int32, tf.int64]: + x = tf.constant([1, 2, 3], dtype=dtype) + # Depending on TF eager tracing path, TypeError or InvalidArgumentError may occur. + with self.assertRaises((TypeError, tf.errors.InvalidArgumentError)): + with tf.device("/device:MUSA:0"): + _ = tf.math.is_nan(x) + + +if __name__ == "__main__": + tf.test.main() \ No newline at end of file From da50648f0fe8bac6c0d064e4b2ec25fc8783fb24 Mon Sep 17 00:00:00 2001 From: albert Date: Sat, 14 Feb 2026 10:34:13 +0800 Subject: [PATCH 2/4] delete useless addn kernel --- musa_ext/kernels/musa_addn_kernel.mu | 119 --------------------------- 1 file changed, 119 deletions(-) delete mode 100644 musa_ext/kernels/musa_addn_kernel.mu diff --git a/musa_ext/kernels/musa_addn_kernel.mu b/musa_ext/kernels/musa_addn_kernel.mu deleted file mode 100644 index 374519c..0000000 --- a/musa_ext/kernels/musa_addn_kernel.mu +++ /dev/null @@ -1,119 +0,0 @@ -#include -#include -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-pragmas" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/bfloat16.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#pragma GCC diagnostic pop - -namespace tensorflow { -namespace musa { - -// 设备函数:加载不同数据类型的值 -__device__ __forceinline__ float LoadFloat(const float* p) { return *p; } -__device__ __forceinline__ void StoreFloat(float* p, float v) { *p = v; } - -__device__ __forceinline__ float LoadFloat(const Eigen::half* p) { - const __half* h_ptr = reinterpret_cast(p); - return __half2float(*h_ptr); -} - -__device__ __forceinline__ void StoreFloat(Eigen::half* p, float v) { - __half h = __float2half(v); - *reinterpret_cast<__half*>(p) = h; -} - -__device__ __forceinline__ float LoadFloat(const bfloat16* p) { - float res = 0.0f; - uint16_t* b_ptr = (uint16_t*)p; - uint32_t* f_ptr = (uint32_t*)&res; - *f_ptr = ((uint32_t)(*b_ptr)) << 16; - return res; -} - -__device__ __forceinline__ void StoreFloat(bfloat16* p, float v) { - uint32_t* f_ptr = (uint32_t*)&v; - uint16_t b_val = (*f_ptr) >> 16; - *reinterpret_cast(p) = b_val; -} - -// 整数类型的支持 -__device__ __forceinline__ int32_t LoadInt32(const int32_t* p) { return *p; } -__device__ __forceinline__ void StoreInt32(int32_t* p, int32_t v) { *p = v; } - -// 使用 tensorflow::int64 而不是 int64_t -__device__ __forceinline__ tensorflow::int64 LoadInt64(const tensorflow::int64* p) { return *p; } -__device__ __forceinline__ void StoreInt64(tensorflow::int64* p, tensorflow::int64 v) { *p = v; } - -// 双精度支持 -__device__ __forceinline__ double LoadDouble(const double* p) { return *p; } -__device__ __forceinline__ void StoreDouble(double* p, double v) { *p = v; } - -// 主要的AddN kernel模板 - 使用 const T* const* 参数类型 -template -__global__ void AddNKernel(const T* const* inputs, T* output, int num_inputs, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - // 使用适当的数据类型进行累加 - T sum = inputs[0][idx]; - for (int i = 1; i < num_inputs; ++i) { - sum += inputs[i][idx]; - } - output[idx] = sum; - } -} - -// 特化版本:使用float中间计算(适用于半精度) -template <> -__global__ void AddNKernel(const Eigen::half* const* inputs, Eigen::half* output, int num_inputs, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - float sum = LoadFloat(&inputs[0][idx]); - for (int i = 1; i < num_inputs; ++i) { - sum += LoadFloat(&inputs[i][idx]); - } - StoreFloat(&output[idx], sum); - } -} - -template <> -__global__ void AddNKernel(const bfloat16* const* inputs, bfloat16* output, int num_inputs, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - float sum = LoadFloat(&inputs[0][idx]); - for (int i = 1; i < num_inputs; ++i) { - sum += LoadFloat(&inputs[i][idx]); - } - StoreFloat(&output[idx], sum); - } -} - -// 启动函数 - 使用 const T* const* 参数类型 -template -void LaunchAddN(const T* const* inputs, T* output, int num_inputs, int n, musaStream_t stream) { - if (n <= 0 || num_inputs <= 0) return; - - int threads = 256; - int blocks = (n + threads - 1) / threads; - - AddNKernel<<>>(inputs, output, num_inputs, n); - - // 检查kernel启动错误 - musaError_t err = musaGetLastError(); - if (err != musaSuccess) { - // 错误处理将在C++层处理 - } -} - -// 显式实例化 - 使用 tensorflow::int64 -// LaunchAddN 是模板函数 必须有显示实例化 编译器才会在编译阶段生成函数代码 -template void LaunchAddN(const float* const*, float*, int, int, musaStream_t); -template void LaunchAddN(const double* const*, double*, int, int, musaStream_t); -template void LaunchAddN(const Eigen::half* const*, Eigen::half*, int, int, musaStream_t); -template void LaunchAddN(const bfloat16* const*, bfloat16*, int, int, musaStream_t); -template void LaunchAddN(const int32_t* const*, int32_t*, int, int, musaStream_t); -template void LaunchAddN(const tensorflow::int64* const*, tensorflow::int64*, int, int, musaStream_t); - -} // namespace musa -} // namespace tensorflow \ No newline at end of file From e0a2c8282434bd7bc31375f5fcc54ca2eee7a996 Mon Sep 17 00:00:00 2001 From: albert Date: Sat, 14 Feb 2026 10:34:13 +0800 Subject: [PATCH 3/4] delete useless addn kernel --- musa_ext/kernels/musa_addn_kernel.mu | 119 -------------------------- musa_ext/kernels/musa_isnan_kernel.mu | 18 ++-- musa_ext/kernels/musa_isnan_op.cc | 17 ++-- test/isnan_op_test.py | 38 ++++---- 4 files changed, 41 insertions(+), 151 deletions(-) delete mode 100644 musa_ext/kernels/musa_addn_kernel.mu diff --git a/musa_ext/kernels/musa_addn_kernel.mu b/musa_ext/kernels/musa_addn_kernel.mu deleted file mode 100644 index 374519c..0000000 --- a/musa_ext/kernels/musa_addn_kernel.mu +++ /dev/null @@ -1,119 +0,0 @@ -#include -#include -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-pragmas" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/bfloat16.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#pragma GCC diagnostic pop - -namespace tensorflow { -namespace musa { - -// 设备函数:加载不同数据类型的值 -__device__ __forceinline__ float LoadFloat(const float* p) { return *p; } -__device__ __forceinline__ void StoreFloat(float* p, float v) { *p = v; } - -__device__ __forceinline__ float LoadFloat(const Eigen::half* p) { - const __half* h_ptr = reinterpret_cast(p); - return __half2float(*h_ptr); -} - -__device__ __forceinline__ void StoreFloat(Eigen::half* p, float v) { - __half h = __float2half(v); - *reinterpret_cast<__half*>(p) = h; -} - -__device__ __forceinline__ float LoadFloat(const bfloat16* p) { - float res = 0.0f; - uint16_t* b_ptr = (uint16_t*)p; - uint32_t* f_ptr = (uint32_t*)&res; - *f_ptr = ((uint32_t)(*b_ptr)) << 16; - return res; -} - -__device__ __forceinline__ void StoreFloat(bfloat16* p, float v) { - uint32_t* f_ptr = (uint32_t*)&v; - uint16_t b_val = (*f_ptr) >> 16; - *reinterpret_cast(p) = b_val; -} - -// 整数类型的支持 -__device__ __forceinline__ int32_t LoadInt32(const int32_t* p) { return *p; } -__device__ __forceinline__ void StoreInt32(int32_t* p, int32_t v) { *p = v; } - -// 使用 tensorflow::int64 而不是 int64_t -__device__ __forceinline__ tensorflow::int64 LoadInt64(const tensorflow::int64* p) { return *p; } -__device__ __forceinline__ void StoreInt64(tensorflow::int64* p, tensorflow::int64 v) { *p = v; } - -// 双精度支持 -__device__ __forceinline__ double LoadDouble(const double* p) { return *p; } -__device__ __forceinline__ void StoreDouble(double* p, double v) { *p = v; } - -// 主要的AddN kernel模板 - 使用 const T* const* 参数类型 -template -__global__ void AddNKernel(const T* const* inputs, T* output, int num_inputs, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - // 使用适当的数据类型进行累加 - T sum = inputs[0][idx]; - for (int i = 1; i < num_inputs; ++i) { - sum += inputs[i][idx]; - } - output[idx] = sum; - } -} - -// 特化版本:使用float中间计算(适用于半精度) -template <> -__global__ void AddNKernel(const Eigen::half* const* inputs, Eigen::half* output, int num_inputs, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - float sum = LoadFloat(&inputs[0][idx]); - for (int i = 1; i < num_inputs; ++i) { - sum += LoadFloat(&inputs[i][idx]); - } - StoreFloat(&output[idx], sum); - } -} - -template <> -__global__ void AddNKernel(const bfloat16* const* inputs, bfloat16* output, int num_inputs, int n) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) { - float sum = LoadFloat(&inputs[0][idx]); - for (int i = 1; i < num_inputs; ++i) { - sum += LoadFloat(&inputs[i][idx]); - } - StoreFloat(&output[idx], sum); - } -} - -// 启动函数 - 使用 const T* const* 参数类型 -template -void LaunchAddN(const T* const* inputs, T* output, int num_inputs, int n, musaStream_t stream) { - if (n <= 0 || num_inputs <= 0) return; - - int threads = 256; - int blocks = (n + threads - 1) / threads; - - AddNKernel<<>>(inputs, output, num_inputs, n); - - // 检查kernel启动错误 - musaError_t err = musaGetLastError(); - if (err != musaSuccess) { - // 错误处理将在C++层处理 - } -} - -// 显式实例化 - 使用 tensorflow::int64 -// LaunchAddN 是模板函数 必须有显示实例化 编译器才会在编译阶段生成函数代码 -template void LaunchAddN(const float* const*, float*, int, int, musaStream_t); -template void LaunchAddN(const double* const*, double*, int, int, musaStream_t); -template void LaunchAddN(const Eigen::half* const*, Eigen::half*, int, int, musaStream_t); -template void LaunchAddN(const bfloat16* const*, bfloat16*, int, int, musaStream_t); -template void LaunchAddN(const int32_t* const*, int32_t*, int, int, musaStream_t); -template void LaunchAddN(const tensorflow::int64* const*, tensorflow::int64*, int, int, musaStream_t); - -} // namespace musa -} // namespace tensorflow \ No newline at end of file diff --git a/musa_ext/kernels/musa_isnan_kernel.mu b/musa_ext/kernels/musa_isnan_kernel.mu index 8346f0b..50fd709 100644 --- a/musa_ext/kernels/musa_isnan_kernel.mu +++ b/musa_ext/kernels/musa_isnan_kernel.mu @@ -1,11 +1,11 @@ -#include -#include #include +#include +#include #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-pragmas" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/types.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #pragma GCC diagnostic pop @@ -28,7 +28,8 @@ __device__ __forceinline__ float LoadFloat(const bfloat16* p) { return res; } -// --------- isnan 判定(float/double 直接用 isnan;half/bf16 转 float) --------- +// --------- isnan 判定(float/double 直接用 isnan;half/bf16 转 float) +// --------- __device__ __forceinline__ bool IsNanValue(float v) { return isnan(v); } __device__ __forceinline__ bool IsNanValue(double v) { return isnan(v); } @@ -43,7 +44,8 @@ __global__ void IsNanKernel(const T* input, bool* output, int n) { // --------- 特化:Eigen::half --------- template <> -__global__ void IsNanKernel(const Eigen::half* input, bool* output, int n) { +__global__ void IsNanKernel(const Eigen::half* input, bool* output, + int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { float v = LoadFloat(&input[idx]); @@ -53,7 +55,8 @@ __global__ void IsNanKernel(const Eigen::half* input, bool* output, // --------- 特化:bfloat16 --------- template <> -__global__ void IsNanKernel(const bfloat16* input, bool* output, int n) { +__global__ void IsNanKernel(const bfloat16* input, bool* output, + int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { float v = LoadFloat(&input[idx]); @@ -79,7 +82,8 @@ void LaunchIsNan(const T* input, bool* output, int n, musaStream_t stream) { // 显式实例化 template void LaunchIsNan(const float*, bool*, int, musaStream_t); template void LaunchIsNan(const double*, bool*, int, musaStream_t); -template void LaunchIsNan(const Eigen::half*, bool*, int, musaStream_t); +template void LaunchIsNan(const Eigen::half*, bool*, int, + musaStream_t); template void LaunchIsNan(const bfloat16*, bool*, int, musaStream_t); } // namespace musa diff --git a/musa_ext/kernels/musa_isnan_op.cc b/musa_ext/kernels/musa_isnan_op.cc index 960d4d2..70bed85 100644 --- a/musa_ext/kernels/musa_isnan_op.cc +++ b/musa_ext/kernels/musa_isnan_op.cc @@ -1,8 +1,8 @@ +#include "mu/device/musa_device.h" +#include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/bfloat16.h" #include "utils_op.h" -#include "mu/device/musa_device.h" // 声明 kernel 启动函数 namespace tensorflow { @@ -30,8 +30,10 @@ class MusaIsNanOp : public MusaOpKernel { if (n64 == 0) return; // TF Tensor 的 NumElements() 是 int64,这里 kernel 用 int - OP_REQUIRES(ctx, n64 <= static_cast(std::numeric_limits::max()), - errors::InvalidArgument("IsNan: tensor is too large, num_elements=", n64)); + OP_REQUIRES(ctx, + n64 <= static_cast(std::numeric_limits::max()), + errors::InvalidArgument( + "IsNan: tensor is too large, num_elements=", n64)); const int n = static_cast(n64); @@ -47,10 +49,9 @@ class MusaIsNanOp : public MusaOpKernel { // 只注册 MUSA 设备上的 kernel;Op 本体(REGISTER_OP)由 TF Core 提供 #define REGISTER_MUSA_ISNAN(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("IsNan") \ - .Device(DEVICE_MTGPU) \ - .TypeConstraint("T"), \ - MusaIsNanOp); + REGISTER_KERNEL_BUILDER( \ + Name("IsNan").Device(DEVICE_MTGPU).TypeConstraint("T"), \ + MusaIsNanOp); REGISTER_MUSA_ISNAN(float); REGISTER_MUSA_ISNAN(double); diff --git a/test/isnan_op_test.py b/test/isnan_op_test.py index 48ef8be..efe2b70 100644 --- a/test/isnan_op_test.py +++ b/test/isnan_op_test.py @@ -1,17 +1,18 @@ -# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved. +#Copyright 2026 The TensorFlow MUSA Authors.All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +#Licensed under the Apache License, Version 2.0(the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +#http: // www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +#== == == == == == == == == == == == == == == == == == == == == == == == == == \ + == == == == == == == == == == == == == """Tests for MUSA IsNan operator. @@ -43,7 +44,7 @@ def _make_input(self, shape, dtype, inject_nan=True, fill_value=None, include_in x_np.flat[1] = -np.inf if inject_nan and x_np.size > 0: - # Put NaNs in deterministic positions +#Put NaNs in deterministic positions x_np.flat[0] = np.nan x_np.flat[x_np.size // 2] = np.nan x_np.flat[-1] = np.nan @@ -56,14 +57,16 @@ def _test_isnan(self, shape, dtype, inject_nan=True, fill_value=None, include_in fill_value=fill_value, include_inf=include_inf) x_tf = tf.constant(x_np, dtype=dtype) - # Numeric proxy so musa_test_utils._compare_cpu_musa_results can use assertAllClose. - # (IsNan's real output is bool; bool comparison is done below.) +#Numeric proxy so \ + musa_test_utils._compare_cpu_musa_results can use assertAllClose. +#(IsNan's real output is bool; bool comparison is done below.) def isnan_proxy(x): return tf.cast(tf.math.is_nan(x), tf.float32) self._compare_cpu_musa_results(isnan_proxy, [x_tf], dtype, rtol=0.0, atol=0.0) - # Also validate the true output dtype/shape and exact bool equality CPU vs MUSA. +#Also validate the true output dtype / \ + shape and exact bool equality CPU vs MUSA. with tf.device("/CPU:0"): cpu_bool = tf.math.is_nan(x_tf) with tf.device("/device:MUSA:0"): @@ -99,7 +102,7 @@ def testIsNanNoNaNs(self): def testIsNanAllNaNs(self): """All NaNs should yield all True.""" for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]: - # For float16/bf16, np.nan is representable; TF will carry NaN. +#For float16 / bf16, np.nan is representable; TF will carry NaN. self._test_isnan([128], dtype, inject_nan=False, fill_value=np.nan) def testIsNanWithInfs(self): @@ -111,7 +114,8 @@ def testIsNanInvalidDType(self): """IsNan should reject non-floating types per TF op definition.""" for dtype in [tf.int32, tf.int64]: x = tf.constant([1, 2, 3], dtype=dtype) - # Depending on TF eager tracing path, TypeError or InvalidArgumentError may occur. +#Depending on TF eager tracing path, \ + TypeError or InvalidArgumentError may occur. with self.assertRaises((TypeError, tf.errors.InvalidArgumentError)): with tf.device("/device:MUSA:0"): _ = tf.math.is_nan(x) From 635e93c6161ff20cf5f51226bac3eccc96aeb05a Mon Sep 17 00:00:00 2001 From: albert Date: Wed, 25 Feb 2026 23:45:08 +0800 Subject: [PATCH 4/4] feat: remove unnecessary comments --- musa_ext/kernels/musa_isnan_kernel.mu | 15 +++++++-------- musa_ext/kernels/musa_isnan_op.cc | 3 --- test/isnan_op_test.py | 10 ++-------- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/musa_ext/kernels/musa_isnan_kernel.mu b/musa_ext/kernels/musa_isnan_kernel.mu index 50fd709..7d76d98 100644 --- a/musa_ext/kernels/musa_isnan_kernel.mu +++ b/musa_ext/kernels/musa_isnan_kernel.mu @@ -12,7 +12,7 @@ namespace tensorflow { namespace musa { -// --------- 工具:half / bfloat16 转 float --------- + __device__ __forceinline__ float LoadFloat(const float* p) { return *p; } __device__ __forceinline__ float LoadFloat(const Eigen::half* p) { @@ -28,12 +28,11 @@ __device__ __forceinline__ float LoadFloat(const bfloat16* p) { return res; } -// --------- isnan 判定(float/double 直接用 isnan;half/bf16 转 float) -// --------- + __device__ __forceinline__ bool IsNanValue(float v) { return isnan(v); } __device__ __forceinline__ bool IsNanValue(double v) { return isnan(v); } -// --------- Kernel:通用模板(float/double)--------- + template __global__ void IsNanKernel(const T* input, bool* output, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -42,7 +41,7 @@ __global__ void IsNanKernel(const T* input, bool* output, int n) { } } -// --------- 特化:Eigen::half --------- + template <> __global__ void IsNanKernel(const Eigen::half* input, bool* output, int n) { @@ -53,7 +52,7 @@ __global__ void IsNanKernel(const Eigen::half* input, bool* output, } } -// --------- 特化:bfloat16 --------- + template <> __global__ void IsNanKernel(const bfloat16* input, bool* output, int n) { @@ -74,12 +73,12 @@ void LaunchIsNan(const T* input, bool* output, int n, musaStream_t stream) { IsNanKernel<<>>(input, output, n); - // kernel 启动错误检查(和你 AddN 风格一致:错误在上层处理也行) + musaError_t err = musaGetLastError(); (void)err; } -// 显式实例化 + template void LaunchIsNan(const float*, bool*, int, musaStream_t); template void LaunchIsNan(const double*, bool*, int, musaStream_t); template void LaunchIsNan(const Eigen::half*, bool*, int, diff --git a/musa_ext/kernels/musa_isnan_op.cc b/musa_ext/kernels/musa_isnan_op.cc index 70bed85..44f59cb 100644 --- a/musa_ext/kernels/musa_isnan_op.cc +++ b/musa_ext/kernels/musa_isnan_op.cc @@ -4,7 +4,6 @@ #include "tensorflow/core/framework/register_types.h" #include "utils_op.h" -// 声明 kernel 启动函数 namespace tensorflow { namespace musa { template @@ -29,7 +28,6 @@ class MusaIsNanOp : public MusaOpKernel { const int64_t n64 = y->NumElements(); if (n64 == 0) return; - // TF Tensor 的 NumElements() 是 int64,这里 kernel 用 int OP_REQUIRES(ctx, n64 <= static_cast(std::numeric_limits::max()), errors::InvalidArgument( @@ -47,7 +45,6 @@ class MusaIsNanOp : public MusaOpKernel { } }; -// 只注册 MUSA 设备上的 kernel;Op 本体(REGISTER_OP)由 TF Core 提供 #define REGISTER_MUSA_ISNAN(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("IsNan").Device(DEVICE_MTGPU).TypeConstraint("T"), \ diff --git a/test/isnan_op_test.py b/test/isnan_op_test.py index efe2b70..5440ea0 100644 --- a/test/isnan_op_test.py +++ b/test/isnan_op_test.py @@ -12,7 +12,6 @@ #See the License for the specific language governing permissions and #limitations under the License. #== == == == == == == == == == == == == == == == == == == == == == == == == == \ - == == == == == == == == == == == == == """Tests for MUSA IsNan operator. @@ -57,16 +56,12 @@ def _test_isnan(self, shape, dtype, inject_nan=True, fill_value=None, include_in fill_value=fill_value, include_inf=include_inf) x_tf = tf.constant(x_np, dtype=dtype) -#Numeric proxy so \ - musa_test_utils._compare_cpu_musa_results can use assertAllClose. -#(IsNan's real output is bool; bool comparison is done below.) def isnan_proxy(x): return tf.cast(tf.math.is_nan(x), tf.float32) self._compare_cpu_musa_results(isnan_proxy, [x_tf], dtype, rtol=0.0, atol=0.0) -#Also validate the true output dtype / \ - shape and exact bool equality CPU vs MUSA. + with tf.device("/CPU:0"): cpu_bool = tf.math.is_nan(x_tf) with tf.device("/device:MUSA:0"): @@ -114,8 +109,7 @@ def testIsNanInvalidDType(self): """IsNan should reject non-floating types per TF op definition.""" for dtype in [tf.int32, tf.int64]: x = tf.constant([1, 2, 3], dtype=dtype) -#Depending on TF eager tracing path, \ - TypeError or InvalidArgumentError may occur. + with self.assertRaises((TypeError, tf.errors.InvalidArgumentError)): with tf.device("/device:MUSA:0"): _ = tf.math.is_nan(x)