Skip to content

Commit 7a41e33

Browse files
committed
opt
1 parent b276855 commit 7a41e33

4 files changed

Lines changed: 16 additions & 20 deletions

File tree

src/nn_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ void silu_mul_cuda(
6060
uint64_t K
6161
) {
6262
auto threads = 1024;
63-
auto blocks = ((K + threads - 1) / threads) * MN;
63+
auto blocks = MN;
6464

65-
silu_mul<<<blocks, threads >>>(
65+
silu_mul<<<blocks, threads>>>(
6666
(half *)x_gpu,
6767
(half *)o_gpu,
6868
K,

src/nn_hip.hip

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ void silu_mul_hip(
6565
hipError_t ret;
6666

6767
auto threads = 1024;
68-
auto blocks = ((K + threads - 1) / threads) * MN;
68+
auto blocks = MN;
6969

70-
silu_mul<<<blocks, threads >>>(
70+
silu_mul<<<blocks, threads>>>(
7171
(half *)x_gpu,
7272
(half *)o_gpu,
7373
K,

src/nn_kernel_cuda.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,18 @@ __global__ void silu_mul(
7474
const int K,
7575
const int MN
7676
) {
77-
int i = blockIdx.x * blockDim.x + threadIdx.x;
78-
int j = i / K;
79-
int k = i - (j * K);
77+
int j = blockIdx.x;
8078

81-
if (k < K && j < MN) {
82-
k += j * (K * 2);
79+
for (int k = threadIdx.x; k < K; k += blockDim.x) {
80+
int i = k + j * (K * 2);
8381

84-
half y = x_gpu[k];
85-
half gate = x_gpu[k + K];
82+
half y = x_gpu[i];
83+
half gate = x_gpu[i + K];
8684

8785
float g = __half2float(gate);
8886
float silu = g / (1.0f + __expf(-g));
8987

90-
o_gpu[i] = __float2half(silu * __half2float(y));
88+
o_gpu[k + j * K] = __float2half(silu * __half2float(y));
9189
}
9290
}
9391

src/nn_kernel_hip.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,20 +75,18 @@ __global__ void silu_mul(
7575
const int K,
7676
const int MN
7777
) {
78-
int i = blockIdx.x * blockDim.x + threadIdx.x;
79-
int j = i / K;
80-
int k = i - (j * K);
78+
int j = blockIdx.x;
8179

82-
if (k < K && j < MN) {
83-
k += j * (K * 2);
80+
for (int k = threadIdx.x; k < K; k += blockDim.x) {
81+
int i = k + j * (K * 2);
8482

85-
half y = x_gpu[k];
86-
half gate = x_gpu[k + K];
83+
half y = x_gpu[i];
84+
half gate = x_gpu[i + K];
8785

8886
float g = __half2float(gate);
8987
float silu = g / (1.0f + __expf(-g));
9088

91-
o_gpu[i] = __float2half(silu * __half2float(y));
89+
o_gpu[k + j * K] = __float2half(silu * __half2float(y));
9290
}
9391
}
9492

0 commit comments

Comments
 (0)