Skip to content

Commit 524e635

Browse files
committed
add support and tests for __nvvm_sin/cos_approx intrinsics
1 parent 02185c0 commit 524e635

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1223,11 +1223,24 @@ def SIN_APPROX_f32 :
12231223
BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz),
12241224
"sin.approx$ftz.f32",
12251225
[(set f32:$dst, (UnaryOpAllowsApproxFn<fsin> f32:$src))]>;
1226+
1227+
// Patterns for NVVM sin intrinsics
1228+
def : Pat<(f32 (int_nvvm_sin_approx_f f32:$a)),
1229+
(SIN_APPROX_f32 f32:$a, 0)>;
1230+
def : Pat<(f32 (int_nvvm_sin_approx_ftz_f f32:$a)),
1231+
(SIN_APPROX_f32 f32:$a, 1)>;
1232+
12261233
def COS_APPROX_f32 :
12271234
BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz),
12281235
"cos.approx$ftz.f32",
12291236
[(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
12301237

1238+
// Patterns for NVVM cos intrinsics
1239+
def : Pat<(f32 (int_nvvm_cos_approx_f f32:$a)),
1240+
(COS_APPROX_f32 f32:$a, 0)>;
1241+
def : Pat<(f32 (int_nvvm_cos_approx_ftz_f f32:$a)),
1242+
(COS_APPROX_f32 f32:$a, 1)>;
1243+
12311244
// NOTE: tanh.approx doesn't support the FTZ flag for f16/f16x2
12321245
def TANH_APPROX_f16 :
12331246
BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "tanh.approx.f16",
@@ -1238,7 +1251,7 @@ def TANH_APPROX_f16x2 :
12381251
BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f16x2",
12391252
[(set v2f16:$dst, (UnaryOpAllowsApproxFn<ftanh> v2f16:$src))]>,
12401253
Requires<[hasPTX<70>, hasSM<75>]>;
1241-
1254+
12421255
def TANH_APPROX_f32 :
12431256
BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz),
12441257
"tanh.approx$ftz.f32",
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// RUN: %{build} -fsycl-device-code-split=per_kernel -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
// Tests NVVM intrinsics for sin and cos approximations
5+
6+
#include <cassert>
7+
#include <cmath>
8+
#include <sycl/detail/core.hpp>
9+
10+
// Forward declarations of NVVM intrinsics
11+
extern "C" {
12+
float __nvvm_sin_approx_f(float);
13+
float __nvvm_sin_approx_ftz_f(float);
14+
float __nvvm_cos_approx_f(float);
15+
float __nvvm_cos_approx_ftz_f(float);
16+
}
17+
18+
constexpr float TOLERANCE = 0.01f; // 1% tolerance for approximations
19+
20+
template <typename Func>
21+
void test_approx(sycl::queue &q, Func intrinsic_func, const char *name,
22+
float input, float expected) {
23+
float result = 0.0f;
24+
25+
{
26+
sycl::buffer<float, 1> buf_result(&result, sycl::range<1>(1));
27+
q.submit([&](sycl::handler &cgh) {
28+
auto acc_result =
29+
buf_result.template get_access<sycl::access::mode::write>(cgh);
30+
cgh.single_task([=]() { acc_result[0] = intrinsic_func(input); });
31+
}).wait();
32+
}
33+
34+
float error = std::abs(result - expected);
35+
assert(error < TOLERANCE && name && " approximation out of tolerance");
36+
}
37+
38+
int main() {
39+
sycl::queue q;
40+
41+
// Test values
42+
const float pi = 3.14159265f;
43+
const float test_values[] = {0.0f, pi / 6.0f, pi / 4.0f, pi / 3.0f,
44+
pi / 2.0f, pi, 2.0f * pi};
45+
46+
// Expected sin values
47+
const float expected_sin[] = {0.0f, 0.5f, 0.707107f, 0.866025f,
48+
1.0f, 0.0f, 0.0f};
49+
50+
// Expected cos values
51+
const float expected_cos[] = {1.0f, 0.866025f, 0.707107f, 0.5f,
52+
0.0f, -1.0f, 1.0f};
53+
54+
// Test __nvvm_sin_approx_f
55+
for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); ++i) {
56+
test_approx(q, __nvvm_sin_approx_f, "sin", test_values[i], expected_sin[i]);
57+
}
58+
59+
// Test __nvvm_sin_approx_ftz_f
60+
for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); ++i) {
61+
test_approx(q, __nvvm_sin_approx_ftz_f, "sin_ftz", test_values[i],
62+
expected_sin[i]);
63+
}
64+
65+
// Test __nvvm_cos_approx_f
66+
for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); ++i) {
67+
test_approx(q, __nvvm_cos_approx_f, "cos", test_values[i], expected_cos[i]);
68+
}
69+
70+
// Test __nvvm_cos_approx_ftz_f
71+
for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); ++i) {
72+
test_approx(q, __nvvm_cos_approx_ftz_f, "cos_ftz", test_values[i],
73+
expected_cos[i]);
74+
}
75+
76+
return 0;
77+
}

0 commit comments

Comments
 (0)