1 | 1 | #include "common.cuh" |
| 2 | +#include "ggml-cuda/vendors/cuda.h" |
| 3 | +#include <cublas_api.h> |
2 | 4 | #include "ggml.h" |
3 | 5 | #include "solve_tri.cuh" |
4 | | - |
5 | | -#define MAX_N_FAST 64 |
6 | | -#define MAX_K_FAST 32 |
7 | | - |
8 | | -// ====================== |
9 | | -// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction |
10 | | -// ====================== |
11 | | -// When ncols_template == 0 the bounds for the loops in this function are not |
12 | | -// known and can't be unrolled. As we want to keep pragma unroll for all other |
13 | | -// cases we supress the clang transformation warning here. |
14 | | -#ifdef __clang__ |
15 | | -# pragma clang diagnostic push |
16 | | -# pragma clang diagnostic ignored "-Wpass-failed" |
17 | | -#endif // __clang__ |
18 | | -template <int n_template, int k_template> |
19 | | -static __global__ void solve_tri_f32_fast(const float * __restrict__ A, |
20 | | - const float * __restrict__ B, |
21 | | - float * __restrict__ X, |
22 | | - const uint3 ne02, |
23 | | - const size_t nb02, |
24 | | - const size_t nb03, |
25 | | - const size_t nb12, |
26 | | - const size_t nb13, |
27 | | - const size_t nb2, |
28 | | - const size_t nb3, |
29 | | - const int n_arg, |
30 | | - const int k_arg) { |
31 | | - const int n = n_template == 0 ? n_arg : n_template; |
32 | | - const int k = k_template == 0 ? k_arg : k_template; |
33 | | - |
34 | | - const int batch_idx = blockIdx.x; |
35 | | - const int lane = threadIdx.x; |
36 | | - const int col_idx = threadIdx.y; |
37 | | - |
38 | | - if (col_idx >= k) { |
| 6 | +#include <cublas_v2.h> |
| 7 | +#include <cuda_runtime_api.h> |
| 8 | +#include <driver_types.h> |
| 9 | + |
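| | +// Fill one device pointer per (i2, i3) batch slice of A and X; cublasStrsmBatched |
| | +// takes arrays of per-batch pointers rather than fixed strides. |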
| 10 | +static __global__ void get_batch_pointers(const float * A, float * X, const float ** A_ptrs, float ** X_ptrs, |
| 11 | + int64_t ne02, int64_t total_batches, |
| 12 | + size_t s02, size_t s03, size_t s2, size_t s3) { |
| 13 | + const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 14 | + if (idx >= total_batches) { |
39 | 15 | return; |
40 | 16 | } |
41 | 17 |
42 | | - const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); |
43 | | - const int64_t i02 = i02_i03.y; |
44 | | - const int64_t i03 = i02_i03.x; |
| 18 | + const int64_t i3 = idx / ne02; |
| 19 | + const int64_t i2 = idx % ne02; |
45 | 20 |
46 | | - const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); |
47 | | - const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); |
48 | | - float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); |
49 | | - |
50 | | - __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; |
51 | | - __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; |
52 | | - |
53 | | - const int offset = threadIdx.x + threadIdx.y * blockDim.x; |
| 21 | + A_ptrs[idx] = A + i3 * s03 + i2 * s02; |
| 22 | + X_ptrs[idx] = X + i3 * s3 + i2 * s2; |
| 23 | +} |
54 | 24 |
55 | | -#pragma unroll |
56 | | - for (int i = 0; i < n * n; i += k * WARP_SIZE) { |
57 | | - int i0 = i + offset; |
58 | | - if (i0 < n * n) { |
59 | | - sA[i0] = A_batch[i0]; |
60 | | - } |
| 25 | +static void solve_tri_f32_cublas(ggml_backend_cuda_context &ctx, |
| 26 | + const float * A, |
| 27 | + const float * B, |
| 28 | + float * X, |
| 29 | + int n, |
| 30 | + int k, |
| 31 | + int64_t ne02, |
| 32 | + int64_t ne03, |
| 33 | + size_t s02, |
| 34 | + size_t s03, |
| 35 | + size_t s12, |
| 36 | + size_t s13, |
| 37 | + size_t s2, |
| 38 | + size_t s3, |
| 39 | + cudaStream_t stream) { |
| 40 | + const float alpha = 1.0f; |
| 41 | + const int64_t total_batches = ne02 * ne03; |
| 42 | + if (total_batches == 0) { |
| 43 | + return; |
61 | 44 | } |
62 | 45 |
63 | | - const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; |
64 | | - |
65 | | -#pragma unroll |
66 | | - for (int i = 0; i < rows_per_warp; i++) { |
67 | | - const int i0 = lane + i * WARP_SIZE; |
68 | | - if (i0 < n) { |
69 | | - sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; |
70 | | - } |
| 46 | +    // cublasStrsmBatched solves in place, overwriting its B argument, so copy B into X first (tensors are contiguous) |
| 47 | + if (X != B) { |
| 48 | +        const int64_t total_elements_BX = (int64_t) n * k * total_batches; |
| 49 | + CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), |
| 50 | + cudaMemcpyDeviceToDevice, stream)); |
71 | 51 | } |
72 | 52 |
73 | | - __syncthreads(); |
74 | | - |
75 | | -#pragma unroll |
76 | | - for (int row = 0; row < n; ++row) { |
77 | | - float sum = 0.0f; |
| 53 | + int id = ggml_cuda_get_device(); |
78 | 54 |
79 | | - { |
80 | | - int j = lane; |
81 | | - if (j < row) { |
82 | | - sum += sA[row * n + j] * sXt[col_idx * n + j]; |
83 | | - } |
84 | | - } |
85 | | - if (row >= WARP_SIZE) { |
86 | | - int j = WARP_SIZE + lane; |
87 | | - if (j < row) { |
88 | | - sum += sA[row * n + j] * sXt[col_idx * n + j]; |
89 | | - } |
90 | | - } |
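| | +    // device-side pointer tables for the batched solve, allocated from the pool of the active device |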
| 55 | + ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches); |
| 56 | + ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches); |
91 | 57 |
92 | | - sum = warp_reduce_sum(sum); |
| 58 | + const float ** A_ptrs_dev = A_ptrs_alloc.get(); |
| 59 | + float ** X_ptrs_dev = X_ptrs_alloc.get(); |
93 | 60 |
94 | | - if (lane == 0) { |
95 | | - const float b_val = sXt[col_idx * n + row]; |
96 | | - const float a_diag = sA[row * n + row]; |
97 | | - // no safeguards for division by zero because that indicates corrupt |
98 | | - // data anyway |
99 | | - sXt[col_idx * n + row] = (b_val - sum) / a_diag; |
100 | | - } |
101 | | - } |
| 61 | + get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>( |
| 62 | + A, X, A_ptrs_dev, X_ptrs_dev, ne02, total_batches, s02, s03, s2, s3); |
102 | 63 |
103 | | - __syncthreads(); |
| 64 | + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); |
104 | 65 |
105 | | -#pragma unroll |
106 | | - for (int i = 0; i < rows_per_warp; i++) { |
107 | | - const int i0 = lane + i * WARP_SIZE; |
108 | | - if (i0 < n) { |
109 | | - X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; |
110 | | - } |
111 | | - } |
112 | | -} |
113 | | -#ifdef __clang__ |
114 | | -# pragma clang diagnostic pop |
115 | | -#endif // __clang__ |
| 66 | +    // Force full-precision FP32 math for the solve: with TF32 tensor-op math enabled the results drift enough to fail the RMSE checks |
| 67 | + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH)); |
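| | +    // ggml stores A (n×n, lower triangular) and X (n×k) row-major, which column-major cuBLAS |
| | +    // reads as A^T (upper triangular) and X^T. Solving A·X = B therefore becomes the |
| | +    // right-side solve X^T·A^T = B^T with m = k, n = n, lda = n, ldb = k. |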
| 68 | + CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), |
| 69 | + CUBLAS_SIDE_RIGHT, |
| 70 | + CUBLAS_FILL_MODE_UPPER, |
| 71 | + CUBLAS_OP_N, |
| 72 | + CUBLAS_DIAG_NON_UNIT, |
| 73 | + k, |
| 74 | + n, |
| 75 | + &alpha, |
| 76 | + A_ptrs_dev, n, |
| 77 | + X_ptrs_dev, k, |
| 78 | + total_batches)); |
116 | 79 |
117 | | -// ====================== |
118 | | -// General Kernel for larger matrices |
119 | | -// Uses a simpler approach with fixed tile size |
120 | | -// ====================== |
121 | | -#define GENERAL_TILE_SIZE 32 |
| 80 | +    // restore the TF32 tensor-op math mode that common.cuh configures as the default |
| 81 | + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH)); |
122 | 82 |
123 | | -template <int n_template, int k_template> |
124 | | -static __global__ void solve_tri_f32_general(const float * __restrict__ A, |
125 | | - const float * __restrict__ B, |
126 | | - float * __restrict__ X, |
127 | | - const uint3 ne02, |
128 | | - const size_t nb02, |
129 | | - const size_t nb03, |
130 | | - const size_t nb12, |
131 | | - const size_t nb13, |
132 | | - const size_t nb2, |
133 | | - const size_t nb3, |
134 | | - const int n_arg, |
135 | | - const int k_arg) { |
136 | | - const int n = n_template == 0 ? n_arg : n_template; |
137 | | - const int k = k_template == 0 ? k_arg : k_template; |
138 | | - |
139 | | - const int batch_idx = blockIdx.x; |
140 | | - const int col_idx = blockIdx.y; |
141 | | - const int tid = threadIdx.x; |
142 | | - |
143 | | - if (col_idx >= k) { |
144 | | - return; |
145 | | - } |
146 | | - const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); |
147 | | - const int64_t i02 = i02_i03.y; |
148 | | - const int64_t i03 = i02_i03.x; |
149 | | - |
150 | | - const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); |
151 | | - const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); |
152 | | - float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); |
153 | | - |
154 | | - // Shared memory for current tile |
155 | | - __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE]; |
156 | | - __shared__ float sB[GENERAL_TILE_SIZE]; |
157 | | - __shared__ float sX[GENERAL_TILE_SIZE]; |
158 | | - |
159 | | - // Process in tiles |
160 | | - for (int tile_start = 0; tile_start < n; tile_start += GENERAL_TILE_SIZE) { |
161 | | - int tile_end = min(tile_start + GENERAL_TILE_SIZE, n); |
162 | | - int tile_n = tile_end - tile_start; |
163 | | - // Load tile of A matrix |
164 | | - for (int i = tid; i < tile_n * tile_n; i += blockDim.x) { |
165 | | - int local_row = i / tile_n; |
166 | | - int local_col = i % tile_n; |
167 | | - int global_row = tile_start + local_row; |
168 | | - int global_col = tile_start + local_col; |
169 | | - if (global_col <= global_row) { |
170 | | - sA[local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col]; |
171 | | - } else { |
172 | | - sA[local_row * GENERAL_TILE_SIZE + local_col] = 0.0f; |
173 | | - } |
174 | | - } |
175 | | - __syncthreads(); |
176 | | - // Load corresponding part of B and initialize X |
177 | | - if (tid < tile_n) { |
178 | | - sB[tid] = B_batch[(tile_start + tid) * k + col_idx]; |
179 | | - sX[tid] = sB[tid]; |
180 | | - } |
181 | | - __syncthreads(); |
182 | | - // Forward substitution for this tile |
183 | | - for (int row = 0; row < tile_n; ++row) { |
184 | | - if (tid == row) { |
185 | | - float sum = 0.0f; |
186 | | - // Sum contributions from previous rows in this tile |
187 | | - for (int j = 0; j < row; ++j) { |
188 | | - sum += sA[row * GENERAL_TILE_SIZE + j] * sX[j]; |
189 | | - } |
190 | | - // Sum contributions from previous tiles |
191 | | - if (tile_start > 0) { |
192 | | - int global_row = tile_start + row; |
193 | | - for (int j = 0; j < tile_start; ++j) { |
194 | | - sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx]; |
195 | | - } |
196 | | - } |
197 | | - const float a_diag = sA[row * GENERAL_TILE_SIZE + row]; |
198 | | - sX[row] = (sB[row] - sum) / a_diag; |
199 | | - } |
200 | | - __syncthreads(); |
201 | | - } |
202 | | - // Store results back to global memory |
203 | | - if (tid < tile_n) { |
204 | | - int global_row = tile_start + tid; |
205 | | - X_batch[global_row * k + col_idx] = sX[tid]; |
206 | | - } |
207 | | - __syncthreads(); |
208 | | - } |
209 | | -} |
210 | | -static void solve_tri_f32_cuda(const float * A, |
211 | | - const float * B, |
212 | | - float * X, |
213 | | - int n, |
214 | | - int k, |
215 | | - int64_t ne02, |
216 | | - int64_t ne03, |
217 | | - size_t nb02, |
218 | | - size_t nb03, |
219 | | - size_t nb12, |
220 | | - size_t nb13, |
221 | | - size_t nb2, |
222 | | - size_t nb3, |
223 | | - cudaStream_t stream) { |
224 | | - const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); |
225 | | - // Choose kernel based on matrix size |
226 | | - if (n <= MAX_N_FAST && k <= MAX_K_FAST) { |
227 | | - // Use fast kernel for small matrices |
228 | | - dim3 threads(WARP_SIZE, k); |
229 | | - dim3 grid(ne02 * ne03); |
230 | | - if (n == 64) { |
231 | | - switch (k) { |
232 | | - case 32: |
233 | | - solve_tri_f32_fast<64, 32> |
234 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
235 | | - break; |
236 | | - case 16: |
237 | | - solve_tri_f32_fast<64, 16> |
238 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
239 | | - break; |
240 | | - case 14: |
241 | | - solve_tri_f32_fast<64, 14> |
242 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
243 | | - break; |
244 | | - case 12: |
245 | | - solve_tri_f32_fast<64, 12> |
246 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
247 | | - break; |
248 | | - case 10: |
249 | | - solve_tri_f32_fast<64, 10> |
250 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
251 | | - break; |
252 | | - case 8: |
253 | | - solve_tri_f32_fast<64, 8> |
254 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
255 | | - break; |
256 | | - case 6: |
257 | | - solve_tri_f32_fast<64, 6> |
258 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
259 | | - break; |
260 | | - case 4: |
261 | | - solve_tri_f32_fast<64, 4> |
262 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
263 | | - break; |
264 | | - case 2: |
265 | | - solve_tri_f32_fast<64, 2> |
266 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
267 | | - break; |
268 | | - case 1: |
269 | | - solve_tri_f32_fast<64, 1> |
270 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); |
271 | | - break; |
272 | | - default: |
273 | | - solve_tri_f32_fast<0, 0> |
274 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); |
275 | | - } |
276 | | - } else { // run general case |
277 | | - solve_tri_f32_fast<0, 0> |
278 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); |
279 | | - } |
280 | | - } else { |
281 | | - // Use general kernel for larger matrices |
282 | | - dim3 threads(256, 1); // 256 threads per block |
283 | | - dim3 grid(ne02 * ne03, k); // One block per column |
284 | | - solve_tri_f32_general<0, 0> |
285 | | - <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); |
286 | | - } |
| 83 | +    GGML_UNUSED_VARS(s12, s13);  // B is copied wholesale above, so its per-batch strides are unused |
287 | 84 | } |
288 | 85 |
| 86 | + |
| 87 | +// ---------------------------------------------------------------------------- |
| 88 | +// Public entry point |
| 89 | +// ---------------------------------------------------------------------------- |
289 | 90 | void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
290 | | - const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix) |
291 | | - const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) |
| 91 | + const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular) |
| 92 | + const ggml_tensor * src1 = dst->src[1]; // B (n×k) |
292 | 93 |
293 | 94 | ggml_is_contiguous(src0); |
294 | 95 | ggml_is_contiguous(src1); |
295 | 96 |
296 | 97 | const int64_t n = src0->ne[0]; |
297 | 98 | const int64_t k = src1->ne[0]; |
298 | | - |
299 | | - solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2], |
300 | | - src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), |
301 | | - src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), |
302 | | - dst->nb[3] / sizeof(float), ctx.stream()); |
| 99 | + const int64_t ne02 = src0->ne[2]; |
| 100 | + const int64_t ne03 = src0->ne[3]; |
| 101 | + |
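| | +    // ggml nb[] strides are in bytes; convert them to float-element strides for the helper |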
| 102 | + solve_tri_f32_cublas(ctx, |
| 103 | + (const float *) src0->data, |
| 104 | + (const float *) src1->data, |
| 105 | + (float *) dst->data, |
| 106 | + n, k, |
| 107 | + ne02, ne03, |
| 108 | + src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), |
| 109 | + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), |
| 110 | + dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float), |
| 111 | + ctx.stream()); |
303 | 112 | } |