// The MIT License (MIT)

// Copyright (c) 2025 Bonson Wong

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef NN_CUDA_CUH
#define NN_CUDA_CUH

#include <math.h>
#include <float.h>
#include <cuda_fp16.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// Applies rotary position embeddings in place to a half-precision tensor,
// using the rotate-half layout: element `rot` is paired with element
// `rot + rotary_half`. One block handles one (batch, head) pair; threadIdx.x
// selects the rotary pair and threadIdx.y strides over the sequence.
__global__ void rotary_emb(
    half *x,
    float *_cos,
    float *_sin,
    const uint64_t seqlen,
    const uint64_t stride_batch,
    const uint64_t stride_seq,
    const uint64_t stride_head,
    const uint64_t rotary_half
) {
    const uint64_t batch = blockIdx.x;
    const uint64_t head = blockIdx.y;
    const uint64_t rot = threadIdx.x;
    const uint64_t tid = threadIdx.y;
    const uint64_t nthreads = blockDim.y;

    if (tid >= seqlen) return;

    half *_o0 = x + (batch * stride_batch) + (head * stride_head) + rot;
    half *_o1 = x + (batch * stride_batch) + (head * stride_head) + rotary_half + rot;

    for (uint64_t seq = tid; seq < seqlen; seq += nthreads) {
        float cos = *(_cos + (seq * rotary_half) + rot);
        float sin = *(_sin + (seq * rotary_half) + rot);

        half *o0 = _o0 + (seq * stride_seq);
        half *o1 = _o1 + (seq * stride_seq);

        float x0 = __half2float(*o0);
        float x1 = __half2float(*o1);

        // Standard 2D rotation of the (x0, x1) pair by this position's angle.
        *o0 = __float2half(x0 * cos - x1 * sin);
        *o1 = __float2half(x0 * sin + x1 * cos);
    }
}
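
// A possible host-side launch for rotary_emb (a sketch, not part of this
// header; pointer names and threads_y are hypothetical). One block per
// (batch, head) pair; blockDim.x must equal rotary_half, since `rot` is
// neither bounds-checked nor strided:
//
//   dim3 grid(batch, n_heads);
//   dim3 block(rotary_half, threads_y); // rotary_half * threads_y <= 1024
//   rotary_emb<<<grid, block>>>(d_x, d_cos, d_sin, seqlen,
//                               stride_batch, stride_seq, stride_head,
//                               rotary_half);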

// Fused SiLU-gated multiply (SwiGLU-style). x_gpu has leading dimension 2*B:
// for each j in [0, MN), B values followed by their B gate values. o_gpu has
// leading dimension B and receives value * silu(gate).
__global__ void silu_mul(
    half *x_gpu,
    half *o_gpu,
    const uint64_t B,
    const uint64_t MN
) {
    uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
    const uint64_t j = blockIdx.y * blockDim.y + threadIdx.y;

    if (i < B && j < MN) {
        uint64_t k = i + j * B;  // output index (leading dim B)
        i += j * (B * 2);        // input index (leading dim 2*B)

        half y = x_gpu[i];
        half gate = x_gpu[i + B];

        float g = __half2float(gate);
        float silu = g / (1.0f + __expf(-g));

        o_gpu[k] = __float2half(silu * __half2float(y));
    }
}
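
// A possible launch for silu_mul (a sketch; the 16x16 block shape is an
// assumption, and d_x/d_o are hypothetical device pointers):
//
//   dim3 block(16, 16);
//   dim3 grid((B + block.x - 1) / block.x, (MN + block.y - 1) / block.y);
//   silu_mul<<<grid, block>>>(d_x, d_o, B, MN);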

// Fused residual-add + RMSNorm:
// output = rmsnorm(input + alpha * residual) * weight.
// One block per row; threads stride over hidden_dim and cooperate on the
// sum-of-squares reduction.
__global__ void rmsnorm(
    const half* input,
    const half* residual,
    const half* weight,
    half* output,
    int batch_size,
    int hidden_dim,
    float alpha,
    float eps
) {
    int row = blockIdx.x; // Which sequence/batch element

    if (row >= batch_size) return;

    const half* x = input + row * hidden_dim;
    const half* res = residual + row * hidden_dim;
    half* y = output + row * hidden_dim;

    // Step 1: Compute sum of squares using shared memory reduction
    __shared__ float shared_sum[32]; // For warp reduction

    float thread_sum = 0.0f;
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        float val = __half2float(x[i]) + (__half2float(res[i]) * alpha);
        thread_sum += val * val;
    }

    // Warp-level reduction
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;

    // Reduce within warp
    for (int offset = 16; offset > 0; offset /= 2) {
        thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
    }

    // First thread in each warp writes to shared memory
    if (lane_id == 0) {
        shared_sum[warp_id] = thread_sum;
    }
    __syncthreads();

    // First warp reduces the warp sums
    float sum_sq = 0.0f;
    if (threadIdx.x < 32) {
        int num_warps = (blockDim.x + 31) / 32;
        sum_sq = (threadIdx.x < num_warps) ? shared_sum[threadIdx.x] : 0.0f;

        for (int offset = 16; offset > 0; offset /= 2) {
            sum_sq += __shfl_down_sync(0xffffffff, sum_sq, offset);
        }
    }

    // Broadcast RMS to all threads
    __shared__ float rms_shared;
    if (threadIdx.x == 0) {
        float mean_sq = sum_sq / hidden_dim;
        rms_shared = rsqrtf(mean_sq + eps); // 1 / sqrt(mean_sq + eps)
    }
    __syncthreads();

    float rms_inv = rms_shared;

    // Step 2: Normalize and apply weight. The residual-added value is
    // recomputed here rather than cached, so the kernel stays correct when
    // hidden_dim > blockDim.x and each thread handles multiple elements.
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        float val = __half2float(x[i]) + (__half2float(res[i]) * alpha);
        float w = __half2float(weight[i]);
        y[i] = __float2half(val * rms_inv * w);
    }
}
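
// A possible launch for rmsnorm (a sketch; 256 threads per block is an
// assumption: any multiple of 32 up to 1024 works with the warp reductions
// above, and each thread then strides over hidden_dim):
//
//   rmsnorm<<<batch_size, 256>>>(d_in, d_res, d_w, d_out,
//                                batch_size, hidden_dim, alpha, eps);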

// Fused residual-add + RMSNorm with int8 re-quantization. The residual is
// read as int8 and dequantized with *residual_scale, added to the input,
// normalized, then written back to the residual buffer quantized with a
// fresh per-row scale (stored to *residual_scale). One thread per element:
// idx indexes the row directly, so blockDim.x must equal hidden_dim, and it
// should be a multiple of 32 so the full-mask shuffles run on full warps.
__global__ void rmsnorm_quant(
    const half* input,
    const half* weight,
    int8_t* residual,
    float* residual_scale,
    int batch_size,
    int hidden_dim,
    float alpha,
    float eps
) {
    int row = blockIdx.x; // Which sequence/batch element
    int idx = threadIdx.x;

    if (row >= batch_size) return;

    const half* inp = input + row * hidden_dim;
    int8_t* res = residual + row * hidden_dim;
    float* res_scale = residual_scale + row;
    float w = __half2float(weight[idx]);

    // Step 1: Compute sum of squares using shared memory reduction
    __shared__ float shared_sum[32]; // For warp reduction

    float thread_sum = 0.0f;
    float val = __half2float(inp[idx]) + (((float)res[idx] * (*res_scale)) * alpha);
    thread_sum += val * val;

    // Warp-level reduction
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;

    // Reduce within warp
    for (int offset = 16; offset > 0; offset /= 2) {
        thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset);
    }

    // First thread in each warp writes to shared memory
    if (lane_id == 0) {
        shared_sum[warp_id] = thread_sum;
    }
    __syncthreads();

    // First warp reduces the warp sums
    float sum_sq = 0.0f;
    if (threadIdx.x < 32) {
        int num_warps = (blockDim.x + 31) / 32;
        sum_sq = (threadIdx.x < num_warps) ? shared_sum[threadIdx.x] : 0.0f;

        for (int offset = 16; offset > 0; offset /= 2) {
            sum_sq += __shfl_down_sync(0xffffffff, sum_sq, offset);
        }
    }

    // Broadcast RMS to all threads
    __shared__ float rms_shared;
    if (threadIdx.x == 0) {
        float mean_sq = sum_sq / hidden_dim;
        rms_shared = rsqrtf(mean_sq + eps); // 1 / sqrt(mean_sq + eps)
    }
    __syncthreads();

    float rms_inv = rms_shared;

    // Step 2: Find max absolute value for output quantization
    __shared__ float shared_max[32];

    float normalized = val * rms_inv * w;
    float thread_max = fabsf(normalized);

    // Reduce to find max
    for (int offset = 16; offset > 0; offset /= 2) {
        thread_max = fmaxf(thread_max, __shfl_down_sync(0xffffffff, thread_max, offset));
    }

    if (lane_id == 0) {
        shared_max[warp_id] = thread_max;
    }
    __syncthreads();

    float abs_max = 0.0f;
    if (threadIdx.x < 32) {
        int num_warps = (blockDim.x + 31) / 32;
        abs_max = (threadIdx.x < num_warps) ? shared_max[threadIdx.x] : 0.0f;

        for (int offset = 16; offset > 0; offset /= 2) {
            abs_max = fmaxf(abs_max, __shfl_down_sync(0xffffffff, abs_max, offset));
        }
    }

    // Write the new per-row dequantization scale
    __shared__ float quant_scale_shared;
    if (threadIdx.x == 0) {
        quant_scale_shared = (abs_max > 0.0f) ? (127.0f / abs_max) : 1.0f;
        *res_scale = 1.0f / quant_scale_shared;
    }
    __syncthreads();

    // Clamp and write the quantized normalized value back to the residual buffer
    float quant_scale = quant_scale_shared;
    int quantized = __float2int_rn(normalized * quant_scale);
    quantized = max(-127, min(127, quantized));
    res[idx] = (int8_t)quantized;
}
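
// A possible launch for rmsnorm_quant (a sketch; pointer names are
// hypothetical). Unlike rmsnorm above, there is no strided loop here, so the
// block size must equal hidden_dim exactly:
//
//   rmsnorm_quant<<<batch_size, hidden_dim>>>(d_in, d_w, d_res, d_res_scale,
//                                             batch_size, hidden_dim,
//                                             alpha, eps);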

#ifdef __cplusplus
}
#endif

#endif // NN_CUDA_CUH