Skip to content

Commit e34f042

Browse files
Commit authored (author name not captured by extraction)
CUDA: fuse muls (#21665)
1 parent d132f22 commit e34f042

3 files changed

Lines changed: 42 additions & 7 deletions

File tree

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
472472
}
473473
}
474474

475+
// Fused element-wise multiply of n_fuse consecutive GGML_OP_MUL nodes in one
// kernel launch. dst->src[0] is the first operand; dst->src[1..n_fuse] hold the
// remaining multiplicands (packed by the caller). Dispatches to the templated
// broadcast implementation, which needs the fuse count as a compile-time
// constant — hence the explicit per-count cases.
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);

    switch (n_fuse) {
        case 2: ggml_cuda_op_fused_binbcast_impl<op_mul, 2>(ctx, dst); break;
        case 3: ggml_cuda_op_fused_binbcast_impl<op_mul, 3>(ctx, dst); break;
        case 4: ggml_cuda_op_fused_binbcast_impl<op_mul, 4>(ctx, dst); break;
        case 5: ggml_cuda_op_fused_binbcast_impl<op_mul, 5>(ctx, dst); break;
        case 6: ggml_cuda_op_fused_binbcast_impl<op_mul, 6>(ctx, dst); break;
        case 7: ggml_cuda_op_fused_binbcast_impl<op_mul, 7>(ctx, dst); break;
        case 8: ggml_cuda_op_fused_binbcast_impl<op_mul, 8>(ctx, dst); break;
        default:
            // Unreachable given the assert above; kept as a hard failure so a
            // bad fuse count is caught even in builds where the first assert
            // is compiled out.
            GGML_ASSERT(false && "Unsupported n_fuse value");
    }
}
504+
475505
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
476506
const ggml_tensor * src0 = dst->src[0];
477507

ggml/src/ggml-cuda/binbcast.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
99
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
1010

1111
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
12+
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37583758
continue;
37593759
}
37603760

3761-
if (node->op == GGML_OP_ADD) {
3761+
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
37623762
int n_fuse = 0;
37633763
ggml_op ops[8];
3764-
std::fill(ops, ops + 8, GGML_OP_ADD);
3764+
std::fill(ops, ops + 8, node->op);
37653765

37663766
for (; n_fuse <= 6; ++n_fuse){
37673767
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
@@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37783778
n_fuse++;
37793779

37803780
if (n_fuse > 1) {
3781-
ggml_tensor fused_add_node;
3782-
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
3781+
ggml_tensor fused_node;
3782+
memcpy(&fused_node, node, sizeof(ggml_tensor));
37833783
for (int j = 0; j < n_fuse - 1; ++j) {
3784-
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3784+
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3785+
}
3786+
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
3787+
if (node->op == GGML_OP_ADD) {
3788+
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
3789+
} else {
3790+
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
37853791
}
3786-
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
3787-
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
37883792
i += n_fuse - 1;
37893793

37903794
continue;

0 commit comments

Comments
 (0)