
Commit c8abbe0

Just use cuBLAS for everything...
1 parent 5efca20 commit c8abbe0


3 files changed: +103 −278 lines changed


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -4619,9 +4619,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_OPT_STEP_SGD:
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
-            return true;
         case GGML_OP_SOLVE_TRI:
             return true;
+
         default:
             return false;
     }
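The only functional content of this hunk is the switch fallthrough: with the standalone `return true;` under `GGML_OP_TRI` removed, `GGML_OP_TRI` now falls through to the same `return true;` as `GGML_OP_SOLVE_TRI`, so both ops keep reporting as supported. A minimal standalone sketch of that pattern follows; the enum values are hypothetical stand-ins, not the real `ggml_op` entries.

// Standalone sketch (not ggml code): grouped case labels sharing one return.
#include <cstdio>

enum sketch_op { SKETCH_OP_TRI, SKETCH_OP_SOLVE_TRI, SKETCH_OP_OTHER };

static bool supports_op_sketch(sketch_op op) {
    switch (op) {
        case SKETCH_OP_TRI:        // falls through to the shared return below
        case SKETCH_OP_SOLVE_TRI:
            return true;
        default:
            return false;
    }
}

int main() {
    std::printf("TRI: %d, SOLVE_TRI: %d, other: %d\n",
                supports_op_sketch(SKETCH_OP_TRI),
                supports_op_sketch(SKETCH_OP_SOLVE_TRI),
                supports_op_sketch(SKETCH_OP_OTHER));
    return 0;
}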

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 83 additions & 274 deletions
@@ -1,303 +1,112 @@
 #include "common.cuh"
+#include "ggml-cuda/vendors/cuda.h"
+#include <cublas_api.h>
 #include "ggml.h"
 #include "solve_tri.cuh"
-
-#define MAX_N_FAST 64
-#define MAX_K_FAST 32
-
-// ======================
-// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
-// ======================
-// When ncols_template == 0 the bounds for the loops in this function are not
-// known and can't be unrolled. As we want to keep pragma unroll for all other
-// cases we supress the clang transformation warning here.
-#ifdef __clang__
-# pragma clang diagnostic push
-# pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-template <int n_template, int k_template>
-static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
-                                          const float * __restrict__ B,
-                                          float * __restrict__ X,
-                                          const uint3 ne02,
-                                          const size_t nb02,
-                                          const size_t nb03,
-                                          const size_t nb12,
-                                          const size_t nb13,
-                                          const size_t nb2,
-                                          const size_t nb3,
-                                          const int n_arg,
-                                          const int k_arg) {
-    const int n = n_template == 0 ? n_arg : n_template;
-    const int k = k_template == 0 ? k_arg : k_template;
-
-    const int batch_idx = blockIdx.x;
-    const int lane = threadIdx.x;
-    const int col_idx = threadIdx.y;
-
-    if (col_idx >= k) {
+#include <cublas_v2.h>
+#include <cuda_runtime_api.h>
+#include <driver_types.h>
+
+static __global__ void get_batch_pointers(const float * A, float * X, const float ** A_ptrs, float ** X_ptrs,
+                                          int64_t ne02, int64_t total_batches,
+                                          size_t s02, size_t s03, size_t s2, size_t s3) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_batches) {
         return;
     }
 
-    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
-    const int64_t i02 = i02_i03.y;
-    const int64_t i03 = i02_i03.x;
+    const int64_t i3 = idx / ne02;
+    const int64_t i2 = idx % ne02;
 
-    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
-    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
-    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
-
-    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
-
-    const int offset = threadIdx.x + threadIdx.y * blockDim.x;
+    A_ptrs[idx] = A + i3 * s03 + i2 * s02;
+    X_ptrs[idx] = X + i3 * s3 + i2 * s2;
+}
 
-#pragma unroll
-    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        int i0 = i + offset;
-        if (i0 < n * n) {
-            sA[i0] = A_batch[i0];
-        }
+static void solve_tri_f32_cublas(ggml_backend_cuda_context &ctx,
+                                 const float * A,
+                                 const float * B,
+                                 float * X,
+                                 int n,
+                                 int k,
+                                 int64_t ne02,
+                                 int64_t ne03,
+                                 size_t s02,
+                                 size_t s03,
+                                 size_t s12,
+                                 size_t s13,
+                                 size_t s2,
+                                 size_t s3,
+                                 cudaStream_t stream) {
+    const float alpha = 1.0f;
+    const int64_t total_batches = ne02 * ne03;
+    if (total_batches == 0) {
+        return;
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
-
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
+    // Bulk copy B -> X (contiguous tensors)
+    if (X != B) {
+        const int64_t total_elements_BX = n * k * total_batches;
+        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float),
+                                   cudaMemcpyDeviceToDevice, stream));
     }
 
-    __syncthreads();
-
-#pragma unroll
-    for (int row = 0; row < n; ++row) {
-        float sum = 0.0f;
+    int id = ggml_cuda_get_device();
 
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
-        }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
-        }
+    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
+    ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
 
-        sum = warp_reduce_sum(sum);
+    const float ** A_ptrs_dev = A_ptrs_alloc.get();
+    float ** X_ptrs_dev = X_ptrs_alloc.get();
 
-        if (lane == 0) {
-            const float b_val = sXt[col_idx * n + row];
-            const float a_diag = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
-        }
-    }
+    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(
+        A, X, A_ptrs_dev, X_ptrs_dev, ne02, total_batches, s02, s03, s2, s3);
 
-    __syncthreads();
+    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
-        }
-    }
-}
-#ifdef __clang__
-# pragma clang diagnostic pop
-#endif // __clang__
+    // Yes, this is necessary, without this we get RMSE errors
+    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
+    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id),
+                                    CUBLAS_SIDE_RIGHT,
+                                    CUBLAS_FILL_MODE_UPPER,
+                                    CUBLAS_OP_N,
+                                    CUBLAS_DIAG_NON_UNIT,
+                                    k,
+                                    n,
+                                    &alpha,
+                                    A_ptrs_dev, n,
+                                    X_ptrs_dev, k,
+                                    total_batches));
 
-// ======================
-// General Kernel for larger matrices
-// Uses a simpler approach with fixed tile size
-// ======================
-#define GENERAL_TILE_SIZE 32
+    // revert to standard mode from common.cuh
+    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
 
-template <int n_template, int k_template>
-static __global__ void solve_tri_f32_general(const float * __restrict__ A,
-                                             const float * __restrict__ B,
-                                             float * __restrict__ X,
-                                             const uint3 ne02,
-                                             const size_t nb02,
-                                             const size_t nb03,
-                                             const size_t nb12,
-                                             const size_t nb13,
-                                             const size_t nb2,
-                                             const size_t nb3,
-                                             const int n_arg,
-                                             const int k_arg) {
-    const int n = n_template == 0 ? n_arg : n_template;
-    const int k = k_template == 0 ? k_arg : k_template;
-
-    const int batch_idx = blockIdx.x;
-    const int col_idx = blockIdx.y;
-    const int tid = threadIdx.x;
-
-    if (col_idx >= k) {
-        return;
-    }
-    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
-    const int64_t i02 = i02_i03.y;
-    const int64_t i03 = i02_i03.x;
-
-    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
-    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
-    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
-
-    // Shared memory for current tile
-    __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE];
-    __shared__ float sB[GENERAL_TILE_SIZE];
-    __shared__ float sX[GENERAL_TILE_SIZE];
-
-    // Process in tiles
-    for (int tile_start = 0; tile_start < n; tile_start += GENERAL_TILE_SIZE) {
-        int tile_end = min(tile_start + GENERAL_TILE_SIZE, n);
-        int tile_n = tile_end - tile_start;
-        // Load tile of A matrix
-        for (int i = tid; i < tile_n * tile_n; i += blockDim.x) {
-            int local_row = i / tile_n;
-            int local_col = i % tile_n;
-            int global_row = tile_start + local_row;
-            int global_col = tile_start + local_col;
-            if (global_col <= global_row) {
-                sA[local_row * GENERAL_TILE_SIZE + local_col] = A_batch[global_row * n + global_col];
-            } else {
-                sA[local_row * GENERAL_TILE_SIZE + local_col] = 0.0f;
-            }
-        }
-        __syncthreads();
-        // Load corresponding part of B and initialize X
-        if (tid < tile_n) {
-            sB[tid] = B_batch[(tile_start + tid) * k + col_idx];
-            sX[tid] = sB[tid];
-        }
-        __syncthreads();
-        // Forward substitution for this tile
-        for (int row = 0; row < tile_n; ++row) {
-            if (tid == row) {
-                float sum = 0.0f;
-                // Sum contributions from previous rows in this tile
-                for (int j = 0; j < row; ++j) {
-                    sum += sA[row * GENERAL_TILE_SIZE + j] * sX[j];
-                }
-                // Sum contributions from previous tiles
-                if (tile_start > 0) {
-                    int global_row = tile_start + row;
-                    for (int j = 0; j < tile_start; ++j) {
-                        sum += A_batch[global_row * n + j] * X_batch[j * k + col_idx];
-                    }
-                }
-                const float a_diag = sA[row * GENERAL_TILE_SIZE + row];
-                sX[row] = (sB[row] - sum) / a_diag;
-            }
-            __syncthreads();
-        }
-        // Store results back to global memory
-        if (tid < tile_n) {
-            int global_row = tile_start + tid;
-            X_batch[global_row * k + col_idx] = sX[tid];
-        }
-        __syncthreads();
-    }
-}
-static void solve_tri_f32_cuda(const float * A,
-                               const float * B,
-                               float * X,
-                               int n,
-                               int k,
-                               int64_t ne02,
-                               int64_t ne03,
-                               size_t nb02,
-                               size_t nb03,
-                               size_t nb12,
-                               size_t nb13,
-                               size_t nb2,
-                               size_t nb3,
-                               cudaStream_t stream) {
-    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
-    // Choose kernel based on matrix size
-    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
-        // Use fast kernel for small matrices
-        dim3 threads(WARP_SIZE, k);
-        dim3 grid(ne02 * ne03);
-        if (n == 64) {
-            switch (k) {
-                case 32:
-                    solve_tri_f32_fast<64, 32>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 16:
-                    solve_tri_f32_fast<64, 16>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 14:
-                    solve_tri_f32_fast<64, 14>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 12:
-                    solve_tri_f32_fast<64, 12>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 10:
-                    solve_tri_f32_fast<64, 10>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 8:
-                    solve_tri_f32_fast<64, 8>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 6:
-                    solve_tri_f32_fast<64, 6>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 4:
-                    solve_tri_f32_fast<64, 4>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 2:
-                    solve_tri_f32_fast<64, 2>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                case 1:
-                    solve_tri_f32_fast<64, 1>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
-                    break;
-                default:
-                    solve_tri_f32_fast<0, 0>
-                        <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
-            }
-        } else { // run general case
-            solve_tri_f32_fast<0, 0>
-                <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
-        }
-    } else {
-        // Use general kernel for larger matrices
-        dim3 threads(256, 1); // 256 threads per block
-        dim3 grid(ne02 * ne03, k); // One block per column
-        solve_tri_f32_general<0, 0>
-            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
-    }
+    GGML_UNUSED_VARS(s12, s13);
 }
 
+
+// ----------------------------------------------------------------------------
+// Public entry point
+// ----------------------------------------------------------------------------
 void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
-    const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
+    const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
+    const ggml_tensor * src1 = dst->src[1]; // B (n×k)
 
     ggml_is_contiguous(src0);
     ggml_is_contiguous(src1);
 
     const int64_t n = src0->ne[0];
     const int64_t k = src1->ne[0];
-
-    solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
-                       src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                       src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
-                       dst->nb[3] / sizeof(float), ctx.stream());
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    solve_tri_f32_cublas(ctx,
+                         (const float *) src0->data,
+                         (const float *) src1->data,
+                         (float *) dst->data,
+                         n, k,
+                         ne02, ne03,
+                         src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+                         src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float),
+                         dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float),
+                         ctx.stream());
 }
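The new path relies on two cuBLAS conventions worth spelling out. First, `cublasStrsmBatched` solves in place, which is why `B` is first copied into `X` and `X` is then handed to cuBLAS as the right-hand side. Second, cuBLAS is column-major while these ggml tensors are row-major: a row-major buffer read column-major is its transpose, so the lower-triangular solve A·X = B becomes Xᵀ·Aᵀ = Bᵀ, i.e. an upper-triangular solve from the right with the m/n arguments swapped to k and n (hence CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, and lda = n, ldb = k). The standalone sketch below reproduces that mapping outside of ggml; sizes, values, and variable names are made up for illustration and error handling is omitted.

// Standalone sketch (not ggml code): solve A*X = B for one batch, where A is
// an n x n lower-triangular matrix and A, B are stored row-major, using the
// same CUBLAS_SIDE_RIGHT / CUBLAS_FILL_MODE_UPPER mapping as solve_tri_f32_cublas.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int n = 3, k = 2, batch = 1;

    // Row-major lower-triangular A (n x n) and right-hand side B (n x k).
    // With these values every row of the solution X should be [1, 2].
    std::vector<float> hA = { 2, 0, 0,
                              1, 3, 0,
                              4, 5, 6 };
    std::vector<float> hB = { 2, 4,
                              4, 8,
                             15, 30 };

    float *dA = nullptr, *dX = nullptr;
    cudaMalloc((void **) &dA, n * n * sizeof(float));
    cudaMalloc((void **) &dX, n * k * sizeof(float));
    cudaMemcpy(dA, hA.data(), n * n * sizeof(float), cudaMemcpyHostToDevice);
    // trsm solves in place, so X starts out as a copy of B.
    cudaMemcpy(dX, hB.data(), n * k * sizeof(float), cudaMemcpyHostToDevice);

    // The *Batched API takes device arrays of per-batch matrix pointers.
    const float *hAp[] = { dA };
    float       *hXp[] = { dX };
    const float **dAp = nullptr;
    float       **dXp = nullptr;
    cudaMalloc((void **) &dAp, batch * sizeof(*dAp));
    cudaMalloc((void **) &dXp, batch * sizeof(*dXp));
    cudaMemcpy(dAp, hAp, batch * sizeof(*dAp), cudaMemcpyHostToDevice);
    cudaMemcpy(dXp, hXp, batch * sizeof(*dXp), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    // Row-major A*X = B is, in cuBLAS' column-major view, X^T * A^T = B^T:
    // an upper-triangular solve from the right with m = k and n = n.
    const float alpha = 1.0f;
    cublasStrsmBatched(handle, CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER,
                       CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT,
                       k, n, &alpha, dAp, n, dXp, k, batch);

    std::vector<float> hX(n * k);
    cudaMemcpy(hX.data(), dX, n * k * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        std::printf("%6.2f %6.2f\n", hX[i * k + 0], hX[i * k + 1]);
    }

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dX); cudaFree(dAp); cudaFree(dXp);
    return 0;
}

The per-batch device pointer arrays are what the *Batched cuBLAS API requires; in the commit they are filled on the GPU by `get_batch_pointers` so no host round-trip is needed, whereas the sketch builds them on the host for simplicity. The sketch also skips the math-mode toggle: a freshly created handle already uses CUBLAS_DEFAULT_MATH, while ggml's shared handle is normally kept in CUBLAS_TF32_TENSOR_OP_MATH and is therefore switched to full-precision FP32 around the solve and back afterwards, per the in-diff comment about RMSE errors.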
