
Conversation

pwilkin (Collaborator) commented Dec 5, 2025

Extend SOLVE_TRI with a tiled kernel so it works on all dimensions. I asked my LLM friends to generate a tiled kernel and added some backend tests for it. Not really worried about optimizations for now since the base case is 64, but adding this in case we want to experiment with different / configurable chunk sizes for hybrid models with chunking (like Qwen3Next or the hopefully-coming Kimi Linear).
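For context, a minimal, hypothetical sketch of the blocked forward-substitution idea (not the PR's actual kernel): it assumes a row-major, lower-triangular n x n matrix A, a row-major n x k right-hand side B, one thread per RHS column, and a fixed TILE of rows staged through shared memory. The names solve_tri_tiled_sketch and TILE are illustrative.

#define TILE 64

// Sketch: solve A * X = B for lower-triangular A (n x n); B and X are n x k.
// Launch with a 1D grid covering the k columns, e.g. <<<(k + 63) / 64, 64>>>.
__global__ void solve_tri_tiled_sketch(const float * A, const float * B, float * X, int n, int k) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;  // RHS column owned by this thread

    __shared__ float sA[TILE][TILE];  // current diagonal tile of A

    for (int t0 = 0; t0 < n; t0 += TILE) {
        const int tn = min(TILE, n - t0);

        // All threads cooperate in loading the diagonal tile A[t0 .. t0+tn, t0 .. t0+tn).
        for (int i = threadIdx.x; i < tn * tn; i += blockDim.x) {
            const int r = i / tn;
            const int c = i % tn;
            sA[r][c] = (c <= r) ? A[(size_t)(t0 + r) * n + (t0 + c)] : 0.0f;
        }
        __syncthreads();

        if (col < k) {
            // Forward substitution over the rows of this tile, one column per thread.
            for (int r = 0; r < tn; ++r) {
                const int row = t0 + r;
                float sum = B[(size_t)row * k + col];
                // Contribution of rows solved in earlier tiles (read back from global memory).
                for (int c = 0; c < t0; ++c) {
                    sum -= A[(size_t)row * n + c] * X[(size_t)c * k + col];
                }
                // Contribution of rows already solved within this tile (shared memory).
                for (int c = 0; c < r; ++c) {
                    sum -= sA[r][c] * X[(size_t)(t0 + c) * k + col];
                }
                X[(size_t)row * k + col] = sum / sA[r][r];
            }
        }
        __syncthreads();  // the tile buffer is reused in the next iteration
    }
}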

pwilkin requested a review from ggerganov as a code owner December 5, 2025 12:14
github-actions bot added the testing, Nvidia GPU, and ggml labels Dec 5, 2025
am17an (Collaborator) commented Dec 6, 2025

Does this supersede #17703?

pwilkin (Collaborator, Author) commented Dec 6, 2025

@am17an nope, that other one is an optimization for the small kernel.

am17an (Collaborator) left a comment

Does this kernel have similar performance to the templated one? If there is not a huge performance difference it is much preferred to keep one kernel for future maintenance + binary size (a big difference in e2e performance of Qwen3-Next would justify its existence, however).

    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);

    // Shared memory for current tile
    __shared__ float sA[GENERAL_TILE_SIZE * GENERAL_TILE_SIZE];
Review comment (Collaborator):

This would probably suffer from the same bank conflict problem, right?
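For reference, one common mitigation (shown only as an illustration, not as what the PR ended up doing) is to pad the tile's leading dimension so that consecutive threads stepping down a column land in different shared-memory banks:

// Padded tile: GENERAL_TILE_SIZE + 1 floats per row instead of GENERAL_TILE_SIZE,
// so column accesses by consecutive threads map to different banks.
__shared__ float sA[GENERAL_TILE_SIZE][GENERAL_TILE_SIZE + 1];

// Indexing then becomes sA[local_row][local_col] instead of
// sA[local_row * GENERAL_TILE_SIZE + local_col].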

Comment on lines 161 to 169
int tile_end = min(tile_start + GENERAL_TILE_SIZE, n);
int tile_n = tile_end - tile_start;
// Load tile of A matrix
for (int i = tid; i < tile_n * tile_n; i += blockDim.x) {
    int local_row = i / tile_n;
    int local_col = i % tile_n;
    int global_row = tile_start + local_row;
    int global_col = tile_start + local_col;
    if (global_col <= global_row) {
Review comment (Collaborator):

const wherever

JohannesGaessler (Collaborator) left a comment

Copy-pasting and modifying the code like this is not acceptable in terms of maintainability. Instead, modify the existing kernel to handle tiling for the case n_template == 0 && k_template == 0.
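A much-simplified sketch of that dispatch idea (illustrative only, not the actual ggml kernel): template parameters of 0 mean "sizes known only at runtime", so a single kernel body covers both the unrolled small case and the general case; a tiled path would slot into the n_template == 0 branch.

// n_template/k_template > 0: sizes known at compile time, loops can unroll.
// n_template == 0 && k_template == 0: fall back to the runtime sizes n and k.
template <int n_template, int k_template>
__global__ void solve_tri_f32_sketch(const float * A, const float * B, float * X, int n, int k) {
    const int nn = n_template == 0 ? n : n_template;
    const int kk = k_template == 0 ? k : k_template;

    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= kk) {
        return;
    }

    // Plain forward substitution, one thread per RHS column.
    for (int row = 0; row < nn; ++row) {
        float sum = B[row * kk + col];
        for (int c = 0; c < row; ++c) {
            sum -= A[row * nn + c] * X[c * kk + col];
        }
        X[row * kk + col] = sum / A[row * nn + row];
    }
}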

Review comment (Collaborator):

I think you added this unintentionally, please remove it from this PR.

pwilkin (Collaborator, Author) replied:

It adds itself automatically due to a GitHub hook installed by the server development code 😐

JohannesGaessler (Collaborator) commented:

Also be aware that this PR will need to be rebased on top of master.

pwilkin (Collaborator, Author) commented Dec 8, 2025

I have a reworked approach using cuBLAS trsmBatched for the big sizes instead, but I haven't been able to figure out some weird race conditions...
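For reference, the rough shape of such a call, assuming column-major per-batch matrices as cuBLAS expects and device arrays of per-batch pointers (the helper name, handle, d_A_ptrs, and d_B_ptrs are illustrative; ggml's row-major layout would need the side/uplo/transpose parameters adjusted):

#include <cublas_v2.h>

// Sketch: solve A_i * X_i = B_i for a batch of lower-triangular n x n systems
// with k right-hand sides each. cuBLAS overwrites B with the solution X.
// d_A_ptrs and d_B_ptrs are device arrays holding one pointer per batch entry.
static cublasStatus_t solve_tri_batched_sketch(cublasHandle_t handle,
                                               const float * const * d_A_ptrs,
                                               float * const * d_B_ptrs,
                                               int n, int k, int batch_count) {
    const float alpha = 1.0f;
    return cublasStrsmBatched(handle,
                              CUBLAS_SIDE_LEFT,        // A is on the left: A * X = alpha * B
                              CUBLAS_FILL_MODE_LOWER,  // lower-triangular A
                              CUBLAS_OP_N,
                              CUBLAS_DIAG_NON_UNIT,
                              n, k, &alpha,
                              d_A_ptrs, n,             // lda = n
                              d_B_ptrs, n,             // ldb = n (B is n x k, column-major)
                              batch_count);
}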

pwilkin (Collaborator, Author) commented Dec 8, 2025

Okay, @JohannesGaessler @am17an question:

I rewrote the entire kernel to just use cublasStrsmBatched. For the small case it is faster than my previous kernel but slower than the new kernel in master; overall the numbers look like this:

Optimized kernel in master:

SOLVE_TRI(type=f32,ne_lhs=[64,64,4,4],ne_rhs=[32,64,4,4]):                   49140 runs -    21.26 us/run -   2.13 MFLOP/run - 100.18 GFLOPS

cuBLAS:

SOLVE_TRI(type=f32,ne_lhs=[64,64,4,4],ne_rhs=[32,64,4,4]):                   40950 runs -    28.86 us/run -   2.13 MFLOP/run -  73.80 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]):                16380 runs -    83.70 us/run -   4.23 MFLOP/run -  50.50 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[64,64,8,32],ne_rhs=[64,64,8,32]):                 10276 runs -   106.43 us/run -  68.16 MFLOP/run - 640.37 GFLOPS
SOLVE_TRI(type=f32,ne_lhs=[128,128,4,32],ne_rhs=[128,128,4,32]):              4070 runs -   255.32 us/run - 270.53 MFLOP/run -   1.06 TFLOPS
SOLVE_TRI(type=f32,ne_lhs=[256,256,4,2],ne_rhs=[128,256,4,2]):                7425 runs -   151.09 us/run -  67.37 MFLOP/run - 445.91 GFLOPS

Does it make sense to keep the optimized kernel for the small cases with the ~25% performance increase or should I just replace the entire kernel with the cuBLAS version?

JohannesGaessler (Collaborator) commented:

Is there an equivalent solution for HIP?

pwilkin (Collaborator, Author) commented Dec 8, 2025

@JohannesGaessler from what I've checked, there are both hipblasStrsm and hipblasStrsmBatched, and there should be mappings for both.
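For illustration only (an assumption about how the mapping would look, not a quote from the tree): ggml's CUDA-to-HIP compatibility header aliases cuBLAS symbols to hipBLAS equivalents via defines, so covering the batched call could look roughly like this:

// Hypothetical additions to the CUDA->HIP mapping header (the real file and
// exact set of defines may differ):
#define cublasStrsmBatched      hipblasStrsmBatched
#define CUBLAS_SIDE_LEFT        HIPBLAS_SIDE_LEFT
#define CUBLAS_FILL_MODE_LOWER  HIPBLAS_FILL_MODE_LOWER
#define CUBLAS_DIAG_NON_UNIT    HIPBLAS_DIAG_NON_UNIT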

JohannesGaessler (Collaborator) commented:

My opinion is that both a manually written kernel and a library call would be fine. I don't think that performance numbers from test-backend-ops matter for this consideration; what matters is whether there is a meaningful difference in end-to-end performance with any actual models. 25% of 1% of the total runtime would be negligible.

pwilkin (Collaborator, Author) commented Dec 8, 2025

Okay, I merged both versions then; I think we can keep it.

SOLVE_TRI is pretty mission-critical for tuning the new generation of hybrid models with recurrent logic, since its efficiency determines the chunk size that can be used (and thus potentially the graph size). It's currently only used in Qwen3Next, but Kimi Linear is probably next in line.
