Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions musa_ext/kernels/musa_isnan_kernel.mu
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#include <math.h>
#include <stdio.h>
#include <string.h>

#include <musa_fp16.h>
#include <musa_runtime.h>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-pragmas"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#pragma GCC diagnostic pop

namespace tensorflow {
namespace musa {


// Identity load: a plain float needs no widening or conversion.
__device__ __forceinline__ float LoadFloat(const float* p) { return *p; }

// Loads an Eigen::half and widens it to float with the device half intrinsic.
// Eigen::half and __half carry the same 16-bit payload, so the pointer
// reinterpret is the conventional bridge between the two types.
__device__ __forceinline__ float LoadFloat(const Eigen::half* p) {
  const __half* h_ptr = reinterpret_cast<const __half*>(p);
  return __half2float(*h_ptr);
}

// Loads a bfloat16 and widens it to float.
//
// bfloat16 stores the upper 16 bits of an IEEE-754 float32, so widening is a
// left shift of the stored bits into the high half of a 32-bit word, with the
// low half zero-filled. The original implementation wrote through a
// reinterpret_cast'ed uint32_t* aliased to a float, which violates strict
// aliasing; memcpy expresses the same bit movement without UB and compiles to
// the identical register moves in device code.
__device__ __forceinline__ float LoadFloat(const bfloat16* p) {
  uint16_t bits;
  memcpy(&bits, p, sizeof(bits));
  const uint32_t f_bits = static_cast<uint32_t>(bits) << 16;
  float res;
  memcpy(&res, &f_bits, sizeof(res));
  return res;
}


// NaN predicates for the natively supported float widths; half/bfloat16
// inputs are widened to float by LoadFloat before reaching these.
__device__ __forceinline__ bool IsNanValue(float v) { return isnan(v); }
__device__ __forceinline__ bool IsNanValue(double v) { return isnan(v); }


// Element-wise NaN test for float/double inputs.
// Expects a 1-D launch with one thread per element; the guard handles the
// tail block when n is not a multiple of blockDim.x.
template <typename T>
__global__ void IsNanKernel(const T* input, bool* output, int n) {
  const int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= n) return;
  output[i] = IsNanValue(input[i]);
}


// Specialization for Eigen::half: widen each element to float first, since
// IsNanValue has no half overload.
template <>
__global__ void IsNanKernel<Eigen::half>(const Eigen::half* input, bool* output,
                                         int n) {
  const int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= n) return;
  output[i] = IsNanValue(LoadFloat(input + i));
}


// Specialization for bfloat16: widen each element to float first, since
// IsNanValue has no bfloat16 overload.
template <>
__global__ void IsNanKernel<bfloat16>(const bfloat16* input, bool* output,
                                      int n) {
  const int i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= n) return;
  output[i] = IsNanValue(LoadFloat(input + i));
}

// --------- Launch ---------
template <typename T>
void LaunchIsNan(const T* input, bool* output, int n, musaStream_t stream) {
if (n <= 0) return;

int threads = 256;
int blocks = (n + threads - 1) / threads;

IsNanKernel<T><<<blocks, threads, 0, stream>>>(input, output, n);


musaError_t err = musaGetLastError();
(void)err;
}


// Explicit instantiations for every dtype registered by the op wrapper, so
// the template definitions above are emitted into this translation unit.
template void LaunchIsNan<float>(const float*, bool*, int, musaStream_t);
template void LaunchIsNan<double>(const double*, bool*, int, musaStream_t);
template void LaunchIsNan<Eigen::half>(const Eigen::half*, bool*, int,
                                       musaStream_t);
template void LaunchIsNan<bfloat16>(const bfloat16*, bool*, int, musaStream_t);

} // namespace musa
} // namespace tensorflow
61 changes: 61 additions & 0 deletions musa_ext/kernels/musa_isnan_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include <limits>

#include "mu/device/musa_device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "utils_op.h"

// Forward declaration of the kernel launcher defined in
// musa_isnan_kernel.mu; keeps this translation unit free of device code.
namespace tensorflow {
namespace musa {
template <typename T>
void LaunchIsNan(const T* input, bool* output, int n, musaStream_t stream);
}  // namespace musa
}  // namespace tensorflow

namespace tensorflow {
namespace musa {

template <typename T>
class MusaIsNanOp : public MusaOpKernel {
public:
explicit MusaIsNanOp(OpKernelConstruction* ctx) : MusaOpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
const Tensor& x = ctx->input(0);

Tensor* y = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));

const int64_t n64 = y->NumElements();
if (n64 == 0) return;

OP_REQUIRES(ctx,
n64 <= static_cast<int64_t>(std::numeric_limits<int>::max()),
errors::InvalidArgument(
"IsNan: tensor is too large, num_elements=", n64));

const int n = static_cast<int>(n64);

const T* x_ptr = x.flat<T>().data();
bool* y_ptr = y->flat<bool>().data();

auto* device = GetDeviceByCtx(ctx);
auto stream = device->GetStream();

LaunchIsNan<T>(x_ptr, y_ptr, n, stream);
}
};

// Registers the MUSA IsNan kernel for each floating dtype that the kernel
// file instantiates; the macro is undefined afterwards to avoid leaking it.
#define REGISTER_MUSA_ISNAN(TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("IsNan").Device(DEVICE_MTGPU).TypeConstraint<TYPE>("T"), \
      MusaIsNanOp<TYPE>);

REGISTER_MUSA_ISNAN(float);
REGISTER_MUSA_ISNAN(double);
REGISTER_MUSA_ISNAN(Eigen::half);
REGISTER_MUSA_ISNAN(bfloat16);

#undef REGISTER_MUSA_ISNAN

} // namespace musa
} // namespace tensorflow
119 changes: 119 additions & 0 deletions test/isnan_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2026 The TensorFlow MUSA Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for MUSA IsNan operator.

This test assumes:
- TensorFlow core has registered the Op 'IsNan' (math_ops.cc).
- The MUSA plugin registers a DEVICE_MTGPU kernel for 'IsNan'.
"""
import numpy as np
import tensorflow as tf

from musa_test_utils import MUSATestCase


class IsNanOpTest(MUSATestCase):
    """Tests for MUSA IsNan operator."""

    def _make_input(self, shape, dtype, inject_nan=True, fill_value=None,
                    include_inf=False):
        """Create a numpy input array for a given TF dtype.

        Args:
            shape: Sequence of ints giving the array shape.
            dtype: TF dtype; bfloat16 data is generated as float32 and
                narrowed later by tf.constant.
            inject_nan: If True, place NaNs at deterministic positions
                (first, middle, and last flat index).
            fill_value: If not None, fill the whole array with this value
                instead of uniform random data.
            include_inf: If True, place +inf and -inf at deterministic
                positions that do not collide with the injected NaNs.

        Returns:
            A numpy array with the requested shape.
        """
        # numpy has no native bfloat16; generate float32 and let tf.constant
        # narrow it.
        np_dtype = np.float32 if dtype == tf.bfloat16 else dtype.as_numpy_dtype

        if fill_value is not None:
            x_np = np.full(shape, fill_value, dtype=np_dtype)
        else:
            x_np = np.random.uniform(-1.0, 1.0, size=shape).astype(np_dtype)

        # Infs go to flat indices 1 and 2 so the NaNs injected below (at
        # indices 0, size // 2 and -1) cannot overwrite them. The previous
        # layout put +inf at index 0, where inject_nan immediately clobbered
        # it and the +inf case was never actually exercised.
        if include_inf and x_np.size > 1:
            x_np.flat[1] = np.inf
            if x_np.size > 2:
                x_np.flat[2] = -np.inf

        if inject_nan and x_np.size > 0:
            # Put NaNs in deterministic positions.
            x_np.flat[0] = np.nan
            x_np.flat[x_np.size // 2] = np.nan
            x_np.flat[-1] = np.nan

        return x_np

    def _test_isnan(self, shape, dtype, inject_nan=True, fill_value=None,
                    include_inf=False):
        """Test IsNan operation with given shape and dtype."""
        x_np = self._make_input(shape, dtype, inject_nan=inject_nan,
                                fill_value=fill_value, include_inf=include_inf)
        x_tf = tf.constant(x_np, dtype=dtype)

        # Route through the shared comparison helper as a float proxy, since
        # the helper compares numeric tensors. is_nan is exact, so no
        # tolerance is allowed.
        def isnan_proxy(x):
            return tf.cast(tf.math.is_nan(x), tf.float32)

        self._compare_cpu_musa_results(isnan_proxy, [x_tf], dtype,
                                       rtol=0.0, atol=0.0)

        # Also compare the raw bool outputs directly, including dtype/shape.
        with tf.device("/CPU:0"):
            cpu_bool = tf.math.is_nan(x_tf)
        with tf.device("/device:MUSA:0"):
            musa_bool = tf.math.is_nan(x_tf)

        self.assertEqual(cpu_bool.dtype, tf.bool)
        self.assertEqual(musa_bool.dtype, tf.bool)
        self.assertAllEqual(cpu_bool.shape.as_list(), x_tf.shape.as_list())
        self.assertAllEqual(musa_bool.shape.as_list(), x_tf.shape.as_list())
        self.assertAllEqual(cpu_bool.numpy(), musa_bool.numpy())

    def testIsNanSmall(self):
        """Small tensor correctness."""
        for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]:
            self._test_isnan([10], dtype)

    def testIsNanLarge(self):
        """Larger tensor correctness."""
        for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]:
            self._test_isnan([256, 256], dtype)

    def testIsNanEmptyTensor(self):
        """Empty tensors should return empty bool tensors with same shape."""
        for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]:
            self._test_isnan([0], dtype, inject_nan=False)
            self._test_isnan([0, 5], dtype, inject_nan=False)

    def testIsNanNoNaNs(self):
        """If there are no NaNs, all outputs should be False."""
        for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]:
            self._test_isnan([1024], dtype, inject_nan=False,
                             include_inf=False)

    def testIsNanAllNaNs(self):
        """All NaNs should yield all True."""
        for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]:
            # np.nan is representable in float16 / bfloat16, so TF will carry
            # the NaN through the narrowing cast.
            self._test_isnan([128], dtype, inject_nan=False,
                             fill_value=np.nan)

    def testIsNanWithInfs(self):
        """Infs are not NaNs; only NaNs should be True."""
        for dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64]:
            self._test_isnan([64], dtype, inject_nan=True, include_inf=True)

    def testIsNanInvalidDType(self):
        """IsNan should reject non-floating types per TF op definition."""
        for dtype in [tf.int32, tf.int64]:
            x = tf.constant([1, 2, 3], dtype=dtype)

            with self.assertRaises((TypeError,
                                    tf.errors.InvalidArgumentError)):
                with tf.device("/device:MUSA:0"):
                    _ = tf.math.is_nan(x)


# Run the suite under the TF test runner when executed directly.
if __name__ == "__main__":
    tf.test.main()