
Commit bf99ac5

own nn stuff

1 parent 2171f09

6 files changed: 486 additions & 76 deletions

include/openfish/openfish.h

Lines changed: 29 additions & 0 deletions
@@ -101,6 +101,35 @@ void openfish_rotary_emb_gpu(
    int stride_head
);

void openfish_silu_mul_gpu(
    void *x_gpu,
    void *o_gpu,
    uint64_t MN,
    uint64_t K
);

void openfish_rmsnorm_gpu(
    const void* input,
    const void* residual,
    const void* weight,
    void* output,
    int MN,
    int K,
    float alpha,
    float eps
);

void openfish_rmsnorm_quant_gpu(
    const void* input,
    const void* weight,
    void* residual,
    void* residual_scale,
    int MN,
    int K,
    float alpha,
    float eps
);

#ifdef __cplusplus
}
#endif
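For orientation, a minimal host-side sketch of how the new entry points might be chained in a transformer FFN tail. The buffer names, sizes, alpha/eps values, and the cudaMalloc setup are illustrative assumptions, not part of this commit; the layout (MN rows of 2*K halves for the gated input, MN rows of K halves elsewhere) follows the kernels further down.

// Hypothetical usage sketch (illustrative; not part of this commit).
#include <openfish/openfish.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

void ffn_tail_example(int MN, int K) {
    half *x, *o, *w, *res, *out;
    cudaMalloc((void **)&x,   (size_t)MN * 2 * K * sizeof(half)); // [MN, 2K]: values then gates
    cudaMalloc((void **)&o,   (size_t)MN * K * sizeof(half));
    cudaMalloc((void **)&w,   (size_t)K * sizeof(half));
    cudaMalloc((void **)&res, (size_t)MN * K * sizeof(half));
    cudaMalloc((void **)&out, (size_t)MN * K * sizeof(half));
    // ... fill buffers ...

    // o = silu(gate) * value, row by row
    openfish_silu_mul_gpu(x, o, MN, K);

    // out = rmsnorm(o + alpha * res) * w, one block per row (requires K <= 1024)
    openfish_rmsnorm_gpu(o, res, w, out, MN, K, /*alpha=*/1.0f, /*eps=*/1e-5f);
}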

src/nn_cuda.cu

Lines changed: 70 additions & 1 deletion
@@ -1,12 +1,81 @@
#include "nn_cuda.h"
#include "error.h"
#include "cuda_utils.cuh"
-#include "rotary_emb_cuda.cuh"
+#include "nn_cuda.cuh"

#include <openfish/openfish_error.h>

#include <cuda_fp16.h>

void rmsnorm_quant_cuda(
    const void* input,
    const void* weight,
    void* residual,
    void* residual_scale,
    int MN,
    int K,
    float alpha,
    float eps
) {
    ASSERT(K <= 1024);

    int threads = K;
    int blocks = MN;

    rmsnorm_quant<<<blocks, threads>>>(
        (half *)input, (half *)weight, (int8_t *)residual, (float *)residual_scale, MN, K, alpha, eps
    );
    checkCudaError();
    cudaDeviceSynchronize();
    checkCudaError();
}

void rmsnorm_cuda(
    const void* input,
    const void* residual,
    const void* weight,
    void* output,
    int MN,
    int K,
    float alpha,
    float eps
) {
    ASSERT(K <= 1024);

    int threads = K;
    int blocks = MN;

    rmsnorm<<<blocks, threads>>>(
        (half *)input, (half *)residual, (half *)weight, (half *)output, MN, K, alpha, eps
    );
    checkCudaError();
    cudaDeviceSynchronize();
    checkCudaError();
}

void silu_mul_cuda(
    void *x_gpu,
    void *o_gpu,
    uint64_t MN,
    uint64_t K
) {
    dim3 block(32, 32);
    dim3 grid(
        (K + block.x - 1) / block.x,
        (MN + block.y - 1) / block.y
    );

    silu_mul<<<grid, block>>>(
        (half *)x_gpu,
        (half *)o_gpu,
        K,
        MN
    );
    checkCudaError();
    cudaDeviceSynchronize();
    checkCudaError();
}

void rotary_emb_cuda(
    void *x_gpu,
    void *sin_gpu,
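A note on the launch configurations above: the two rmsnorm wrappers launch one block per row with exactly one thread per element, which is why they assert K <= 1024 (the per-block thread limit on current NVIDIA GPUs), while silu_mul_cuda tiles the [MN, K] output with 32x32 blocks and a ceiling-division grid. A small illustrative helper showing that arithmetic (the helper name is not from this commit):

// Illustrative ceiling-division grid sizing, as used by silu_mul_cuda.
// For K = 100 and a 32-wide block, (100 + 31) / 32 = 4 blocks span
// columns 0..127; the kernel's (i < B && j < MN) guard masks the overhang.
#include <cuda_runtime.h>

static dim3 ceil_div_grid(unsigned K, unsigned MN, dim3 block) {
    return dim3((K + block.x - 1) / block.x,
                (MN + block.y - 1) / block.y);
}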

src/nn_cuda.cuh

Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
// The MIT License (MIT)

// Copyright (c) 2025 Bonson Wong

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef NN_CUDA_CUH
#define NN_CUDA_CUH

#include <math.h>
#include <float.h>
#include <cuda_fp16.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// In-place rotary position embedding: each (batch, head) block rotates the
// pairs (x[rot], x[rot + rotary_half]) by per-position angles whose cosines
// and sines are precomputed in _cos/_sin, striding over sequence positions.
__global__ void rotary_emb(
    half *x,
    float *_cos,
    float *_sin,
    const uint64_t seqlen,
    const uint64_t stride_batch,
    const uint64_t stride_seq,
    const uint64_t stride_head,
    const uint64_t rotary_half
) {
    const uint64_t batch = blockIdx.x;
    const uint64_t head = blockIdx.y;
    const uint64_t rot = threadIdx.x;
    const uint64_t tid = threadIdx.y;
    const uint64_t nthreads = blockDim.y;

    if (tid >= seqlen) return;

    half *_o0 = x + (batch * stride_batch) + (head * stride_head) + rot;
    half *_o1 = x + (batch * stride_batch) + (head * stride_head) + rotary_half + rot;

    for (int seq = tid; seq < seqlen; seq += nthreads) {
        float cos = *(_cos + (seq * rotary_half) + rot);
        float sin = *(_sin + (seq * rotary_half) + rot);

        half *o0 = _o0 + (seq * stride_seq);
        half *o1 = _o1 + (seq * stride_seq);

        float x0 = __half2float(*o0);
        float x1 = __half2float(*o1);

        *o0 = __float2half(x0 * cos - x1 * sin);
        *o1 = __float2half(x0 * sin + x1 * cos);
    }
}
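The kernel applies the standard RoPE rotation: for each rotary index r, the pair (x0, x1) = (x[r], x[r + rotary_half]) is treated as a complex number and rotated by the position-dependent angle. A minimal CPU reference for one (batch, head, seq) slot, assuming the same [seqlen, rotary_half] cos/sin tables; names are illustrative, not from this commit:

// Hypothetical CPU reference for rotary_emb (illustrative, not in this commit).
// x points at 2 * rotary_half floats for a single (batch, head, seq) slot;
// cos_tab/sin_tab are the [seqlen, rotary_half] tables the kernel reads.
void rotary_emb_ref(float *x, const float *cos_tab, const float *sin_tab,
                    int seq, int rotary_half) {
    for (int r = 0; r < rotary_half; r++) {
        float c = cos_tab[seq * rotary_half + r];
        float s = sin_tab[seq * rotary_half + r];
        float x0 = x[r];
        float x1 = x[r + rotary_half];
        x[r]               = x0 * c - x1 * s;  // real part of (x0 + i*x1) * e^{i*theta}
        x[r + rotary_half] = x0 * s + x1 * c;  // imaginary part
    }
}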

// Gated SiLU multiply over an [MN, 2*B] input producing an [MN, B] output.
__global__ void silu_mul(
    half *x_gpu,
    half *o_gpu,
    const uint64_t B,
    const uint64_t MN
) {
    uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
    const uint64_t j = blockIdx.y * blockDim.y + threadIdx.y;

    if (i < B && j < MN) {
        uint64_t k = i + j * B;   // output index in the [MN, B] result
        i += j * (B * 2);         // index of the value in row j of the [MN, 2*B] input

        half y = x_gpu[i];        // first half of the row: values
        half gate = x_gpu[i + B]; // second half of the row: gates

        float g = __half2float(gate);
        float silu = g / (1.0f + __expf(-g)); // silu(g) = g * sigmoid(g)

        o_gpu[k] = __float2half(silu * __half2float(y));
    }
}
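For clarity, a CPU reference of the same computation under the layout the kernel assumes: row j of the input holds B values followed by B gates, and output element (j, i) is silu(gate) * value. Names are illustrative, not from this commit:

// Hypothetical CPU reference for silu_mul (illustrative, not in this commit).
// x: MN rows of 2*B floats (value[0..B), then gate[0..B)); o: MN rows of B floats.
#include <math.h>
#include <stddef.h>

void silu_mul_ref(const float *x, float *o, int MN, int B) {
    for (int j = 0; j < MN; j++) {
        const float *row = x + (size_t)j * 2 * B;
        for (int i = 0; i < B; i++) {
            float g = row[i + B];
            o[(size_t)j * B + i] = (g / (1.0f + expf(-g))) * row[i];
        }
    }
}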

__global__ void rmsnorm(
    const half* input,
    const half* residual,
    const half* weight,
    half* output,
    int batch_size,
    int hidden_dim,
    float alpha,
    float eps
) {
    int row = blockIdx.x; // which sequence/batch element

    if (row >= batch_size) return;

    const half* x = input + row * hidden_dim;
    const half* res = residual + row * hidden_dim;
    half* y = output + row * hidden_dim;

    // Step 1: compute the sum of squares with a two-level (warp, block) reduction
    __shared__ float shared_sum[32]; // one slot per warp

    float thread_sum = 0.0f;
    // Cached residual-added value. The host launches with blockDim.x == hidden_dim
    // (K <= 1024), so this loop runs exactly once per thread; if blockDim.x were
    // ever smaller than hidden_dim, one value per iteration would need caching.
    float x_new;
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        float val = __half2float(x[i]) + (__half2float(res[i]) * alpha);
        x_new = val;
        thread_sum += val * val;
    }

    // Warp-level reduction
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;

    // Reduce within each warp
    for (int offset = 16; offset > 0; offset /= 2) {
        thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
    }

    // First thread in each warp writes its partial sum to shared memory
    if (lane_id == 0) {
        shared_sum[warp_id] = thread_sum;
    }
    __syncthreads();

    // First warp reduces the per-warp sums
    float sum_sq = 0.0f;
    if (threadIdx.x < 32) {
        int num_warps = (blockDim.x + 31) / 32;
        sum_sq = (threadIdx.x < num_warps) ? shared_sum[threadIdx.x] : 0.0f;

        for (int offset = 16; offset > 0; offset /= 2) {
            sum_sq += __shfl_down_sync(0xffffffff, sum_sq, offset);
        }
    }

    // Broadcast the reciprocal RMS to all threads
    __shared__ float rms_shared;
    if (threadIdx.x == 0) {
        float mean_sq = sum_sq / hidden_dim;
        rms_shared = rsqrtf(mean_sq + eps); // 1 / sqrt(mean_sq + eps)
    }
    __syncthreads();

    float rms_inv = rms_shared;

    // Step 2: normalize and apply the weight
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        float w = __half2float(weight[i]);
        y[i] = __float2half(x_new * rms_inv * w);
    }
}
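In formula terms, with v = x + alpha * res per row, the kernel computes y[i] = w[i] * v[i] / sqrt(mean(v^2) + eps). A minimal single-row CPU reference (illustrative, not from this commit):

// Hypothetical CPU reference for one row of rmsnorm (not in this commit).
// y[i] = w[i] * v[i] / sqrt(mean(v^2) + eps), where v = x + alpha * res.
#include <math.h>

void rmsnorm_ref(const float *x, const float *res, const float *w,
                 float *y, int K, float alpha, float eps) {
    float sum_sq = 0.0f;
    for (int i = 0; i < K; i++) {
        float v = x[i] + alpha * res[i];
        sum_sq += v * v;
    }
    float rms_inv = 1.0f / sqrtf(sum_sq / K + eps);
    for (int i = 0; i < K; i++) {
        float v = x[i] + alpha * res[i];
        y[i] = v * rms_inv * w[i];
    }
}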

__global__ void rmsnorm_quant(
    const half* input,
    const half* weight,
    int8_t* residual,
    float* residual_scale,
    int batch_size,
    int hidden_dim,
    float alpha,
    float eps
) {
    int row = blockIdx.x; // which sequence/batch element
    int idx = threadIdx.x; // assumes blockDim.x == hidden_dim (one thread per element)

    if (row >= batch_size) return;

    const half* inp = input + row * hidden_dim;
    int8_t* res = residual + row * hidden_dim;
    float* res_scale = residual_scale + row;
    float w = __half2float(weight[idx]);

    // Step 1: compute the sum of squares with a two-level (warp, block) reduction.
    // The incoming residual is dequantized (res[idx] * *res_scale) before the add;
    // the normalized result is quantized back into the same buffers below.
    __shared__ float shared_sum[32]; // one slot per warp

    float thread_sum = 0.0f;
    float val = __half2float(inp[idx]) + (((float)res[idx] * (*res_scale)) * alpha);
    thread_sum += val * val;

    // Warp-level reduction
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;

    // Reduce within each warp
    for (int offset = 16; offset > 0; offset /= 2) {
        thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
    }

    // First thread in each warp writes its partial sum to shared memory
    if (lane_id == 0) {
        shared_sum[warp_id] = thread_sum;
    }
    __syncthreads();

    // First warp reduces the per-warp sums
    float sum_sq = 0.0f;
    if (threadIdx.x < 32) {
        int num_warps = (blockDim.x + 31) / 32;
        sum_sq = (threadIdx.x < num_warps) ? shared_sum[threadIdx.x] : 0.0f;

        for (int offset = 16; offset > 0; offset /= 2) {
            sum_sq += __shfl_down_sync(0xffffffff, sum_sq, offset);
        }
    }

    // Broadcast the reciprocal RMS to all threads
    __shared__ float rms_shared;
    if (threadIdx.x == 0) {
        float mean_sq = sum_sq / hidden_dim;
        rms_shared = rsqrtf(mean_sq + eps); // 1 / sqrt(mean_sq + eps)
    }
    __syncthreads();

    float rms_inv = rms_shared;

    // Step 2: find the row's max absolute value for symmetric int8 quantization
    __shared__ float shared_max[32];

    float thread_max = 0.0f;
    float normalized = val * rms_inv * w;
    thread_max = fmaxf(thread_max, fabsf(normalized));

    // Reduce within each warp to find the max
    for (int offset = 16; offset > 0; offset /= 2) {
        thread_max = fmaxf(thread_max, __shfl_down_sync(0xffffffff, thread_max, offset));
    }

    if (lane_id == 0) {
        shared_max[warp_id] = thread_max;
    }
    __syncthreads();

    float abs_max = 0.0f;
    if (threadIdx.x < 32) {
        int num_warps = (blockDim.x + 31) / 32;
        abs_max = (threadIdx.x < num_warps) ? shared_max[threadIdx.x] : 0.0f;

        for (int offset = 16; offset > 0; offset /= 2) {
            abs_max = fmaxf(abs_max, __shfl_down_sync(0xffffffff, abs_max, offset));
        }
    }

    // Write the per-row quantization scale (its reciprocal is the dequant scale)
    __shared__ float quant_scale_shared;
    if (threadIdx.x == 0) {
        quant_scale_shared = (abs_max > 0.0f) ? (127.0f / abs_max) : 1.0f;
        *res_scale = 1.0f / quant_scale_shared;
    }
    __syncthreads();

    // Clamp and write the quantized normalized value back into the residual buffer
    float quant_scale = quant_scale_shared;
    int quantized = __float2int_rn(normalized * quant_scale);
    quantized = max(-127, min(127, quantized));
    res[idx] = (int8_t)quantized;
}
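The scheme above is symmetric per-row int8 quantization: the scale is 127 / absmax of the normalized row, values are rounded to nearest and clamped to [-127, 127], and residual_scale stores the reciprocal so the next read dequantizes with a single multiply (as the res[idx] * *res_scale at the top of the kernel does). A round-trip sketch under those assumptions, with illustrative names:

// Hypothetical round-trip sketch of the per-row symmetric int8 scheme
// used above (not in this commit).
#include <math.h>
#include <stdint.h>

void quant_row_ref(const float *v, int8_t *q, float *dequant_scale, int K) {
    float abs_max = 0.0f;
    for (int i = 0; i < K; i++) abs_max = fmaxf(abs_max, fabsf(v[i]));

    float scale = (abs_max > 0.0f) ? (127.0f / abs_max) : 1.0f;
    *dequant_scale = 1.0f / scale; // what the kernel stores in residual_scale

    for (int i = 0; i < K; i++) {
        int x = (int)lrintf(v[i] * scale); // round to nearest
        if (x > 127) x = 127;
        if (x < -127) x = -127;
        q[i] = (int8_t)x;
    }
    // Dequantize later as: v_approx[i] = q[i] * (*dequant_scale)
}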

#ifdef __cplusplus
}
#endif

#endif // NN_CUDA_CUH
