File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -60,9 +60,9 @@ void silu_mul_cuda(
6060 uint64_t K
6161) {
6262 auto threads = 1024 ;
63- auto blocks = ((K + threads - 1 ) / threads) * MN;
63+ auto blocks = MN;
6464
65- silu_mul<<<blocks, threads >>> (
65+ silu_mul<<<blocks, threads>>> (
6666 (half *)x_gpu,
6767 (half *)o_gpu,
6868 K,
Original file line number Diff line number Diff line change @@ -65,9 +65,9 @@ void silu_mul_hip(
6565 hipError_t ret;
6666
6767 auto threads = 1024 ;
68- auto blocks = ((K + threads - 1 ) / threads) * MN;
68+ auto blocks = MN;
6969
70- silu_mul<<<blocks, threads >>>(
70+ silu_mul<<<blocks, threads>>>(
7171 (half *)x_gpu,
7272 (half *)o_gpu,
7373 K,
Original file line number Diff line number Diff line change @@ -74,20 +74,18 @@ __global__ void silu_mul(
7474 const int K ,
7575 const int MN
7676) {
77- int i = blockIdx .x * blockDim .x + threadIdx .x ;
78- int j = i / K ;
79- int k = i - (j * K );
77+ int j = blockIdx .x ;
8078
81- if ( k < K && j < MN ) {
82- k += j * (K * 2 );
79+ for ( int k = threadIdx . x ; k < K ; k += blockDim . x ) {
80+ int i = k + j * (K * 2 );
8381
84- half y = x_gpu [k ];
85- half gate = x_gpu [k + K ];
82+ half y = x_gpu [i ];
83+ half gate = x_gpu [i + K ];
8684
8785 float g = __half2float (gate );
8886 float silu = g / (1.0f + __expf (- g ));
8987
90- o_gpu [i ] = __float2half (silu * __half2float (y ));
88+ o_gpu [k + j * K ] = __float2half (silu * __half2float (y ));
9189 }
9290}
9391
Original file line number Diff line number Diff line change @@ -75,20 +75,18 @@ __global__ void silu_mul(
7575 const int K ,
7676 const int MN
7777) {
78- int i = blockIdx .x * blockDim .x + threadIdx .x ;
79- int j = i / K ;
80- int k = i - (j * K );
78+ int j = blockIdx .x ;
8179
82- if ( k < K && j < MN ) {
83- k += j * (K * 2 );
80+ for ( int k = threadIdx . x ; k < K ; k += blockDim . x ) {
81+ int i = k + j * (K * 2 );
8482
85- half y = x_gpu [k ];
86- half gate = x_gpu [k + K ];
83+ half y = x_gpu [i ];
84+ half gate = x_gpu [i + K ];
8785
8886 float g = __half2float (gate );
8987 float silu = g / (1.0f + __expf (- g ));
9088
91- o_gpu [i ] = __float2half (silu * __half2float (y ));
89+ o_gpu [k + j * K ] = __float2half (silu * __half2float (y ));
9290 }
9391}
9492
You can’t perform that action at this time.
0 commit comments