[LLVM][Codegen] Cast NaN to bool gives true #18646
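Background for the fix: numeric float-to-bool conversion follows the rule "x != 0", and NaN compares unequal to every value, including zero, so casting NaN (and inf) to bool must yield true. A quick NumPy check of the expected semantics (illustration only, not part of the PR's diff):

import numpy as np

# float -> bool is defined as "x != 0"; NaN != 0 is true,
# so NaN must cast to True. Infinity is nonzero, also True.
a = np.array([0.0, 1.0, np.nan, np.inf], dtype="float32")
print(a.astype(bool))  # [False  True  True  True]

The PR adds a regression test over exactly these inputs: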
@@ -378,6 +378,31 @@ def check_llvm(n):
     check_llvm(64)


+@tvm.testing.requires_llvm
+def test_llvm_cast_float_to_bool():
+    a_np = np.array([0.0, 1.0, np.nan, np.inf], dtype="float32")
+    n = a_np.shape[0]
+
+    A = te.placeholder((n,), name="A", dtype="float32")
+    C = te.compute((n,), lambda i: A[i].astype("bool"), name="C")
+
+    # Convert to TIR and create a schedule.
+    mod = te.create_prim_func([A, C])
+    sch = tir.Schedule(mod)
+
+    # Build the kernel.
+    f = tvm.compile(sch.mod, target="llvm")
+    dev = tvm.cpu(0)
+
+    # Launch the kernel.
+    a = tvm.runtime.tensor(a_np, dev)
+    c = tvm.runtime.empty((n,), dtype="bool", device=dev)
+    f(a, c)
+    c_np = np.array([False, True, True, True], dtype="bool")
+
+    tvm.testing.assert_allclose(c.numpy(), c_np)
Comment on lines +381 to +403

Contributor:
This is a great test case that covers the essential scenarios for casting floats to booleans. To make it even more comprehensive, I suggest parameterizing it to run against multiple float dtypes (float16, float32, and float64):

@tvm.testing.requires_llvm
@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
def test_llvm_cast_float_to_bool(dtype):
    if dtype == "float16" and tvm.target.codegen.llvm_version_major() < 8:
        pytest.skip("float16 support requires LLVM 8 or greater")
    a_np = np.array([0.0, 1.0, np.nan, np.inf], dtype=dtype)
    n = a_np.shape[0]
    A = te.placeholder((n,), name="A", dtype=dtype)
    C = te.compute((n,), lambda i: A[i].astype("bool"), name="C")
    # Convert to TIR and create a schedule.
    mod = te.create_prim_func([A, C])
    sch = tir.Schedule(mod)
    # Build the kernel.
    f = tvm.compile(sch.mod, target="llvm")
    dev = tvm.cpu(0)
    # Launch the kernel.
    a = tvm.runtime.tensor(a_np, dev)
    c = tvm.runtime.empty((n,), dtype="bool", device=dev)
    f(a, c)
    c_np = np.array([False, True, True, True], dtype="bool")
    tvm.testing.assert_allclose(c.numpy(), c_np)
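With this parametrization, pytest generates one test case per dtype and skips the float16 case on LLVM versions older than 8, extending coverage across the common float widths without duplicating the test body.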
 @tvm.testing.requires_llvm
 def test_rank_zero():
     def check_llvm(n):
Use TVMScript instead of the TE schedule, as in the other parts of the tests.
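For illustration, here is a minimal sketch of what a TVMScript version of the test might look like. This is an assumption about the intended style, not code from the PR; the prim_func name, the block structure, and the use of a "bool" output buffer are all hypothetical:

import numpy as np
import tvm
import tvm.testing
from tvm.script import tir as T


@tvm.testing.requires_llvm
def test_llvm_cast_float_to_bool():
    # Hypothetical TVMScript equivalent of the TE-based test above.
    @T.prim_func
    def cast_to_bool(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "bool")):
        for i in range(4):
            with T.block("C"):
                vi = T.axis.spatial(4, i)
                # T.Cast to "bool" is the cast under test: NaN must map to True.
                C[vi] = T.Cast("bool", A[vi])

    f = tvm.compile(cast_to_bool, target="llvm")
    dev = tvm.cpu(0)
    a = tvm.runtime.tensor(np.array([0.0, 1.0, np.nan, np.inf], dtype="float32"), dev)
    c = tvm.runtime.empty((4,), dtype="bool", device=dev)
    f(a, c)
    tvm.testing.assert_allclose(c.numpy(), np.array([False, True, True, True]))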
I will edit the test code in a new PR.