Commit ec16a07

gaugarg-nv and am17an authored
Optimize MOE GEMV kernel for BS > 1. (ggml-org#20905)
* Optimize MOE GEMV kernel for BS > 1. The previous MOE kernel for BS > 1 launched too many thread blocks (nrows_x, nchannels_dst, ncols_dst) with very little work per block: a block of (32, 4) computed the inner dot product for a single row. The new mul_mat_vec_q_moe kernel is dedicated to the MoE multi-token case, with grid (ceil(nrows_x/rpb), nchannels_dst) and block (warp_size, ncols_dst). Each warp handles two rows independently, using warp-level reduction only (no shared-memory synchronization). This change does not increase compilation time, since a single template instance is needed per type. It also simplifies the original GEMV kernel and removes the `is_multi_token_id` specialization.

* Remove em-dashes

* Cherry-pick changes from @am17an's PR ggml-org#20885 to enable the small_k optimization only for cases where it benefits. Increase the max batch size for MMVQ kernels for MUL_MAT_ID to 8.

* Make the max batch size for the MOE GEMV kernel configurable based on GPU arch and datatype

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
1 parent f5d1c41 commit ec16a07

3 files changed

Lines changed: 358 additions & 59 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 7 deletions
@@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
     if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
         if (ggml_is_quantized(src0->type)) {
-            if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
+            const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc);
+            if (ne2 <= mmvq_mmid_max) {
                 ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
                 return;
             }
@@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
         }

         // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
-        if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
-            // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
-            // TODO: figure out a way to enable for larger batch sizes, without hurting performance
-            // ref: https://github.com/ggml-org/llama.cpp/pull/18958
-            use_cuda_graph = false;
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+            const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
+            if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) {
+                // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
+                // TODO: figure out a way to enable for larger batch sizes, without hurting performance
+                // ref: https://github.com/ggml-org/llama.cpp/pull/18958
+                use_cuda_graph = false;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
 #endif
+            }
         }

         if (!use_cuda_graph) {
