@@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37583758 continue ;
37593759 }
37603760
3761- if (node->op == GGML_OP_ADD) {
3761+ if (node->op == GGML_OP_ADD || node-> op == GGML_OP_MUL ) {
37623762 int n_fuse = 0 ;
37633763 ggml_op ops[8 ];
3764- std::fill (ops, ops + 8 , GGML_OP_ADD );
3764+ std::fill (ops, ops + 8 , node-> op );
37653765
37663766 for (; n_fuse <= 6 ; ++n_fuse){
37673767 if (!ggml_can_fuse (cgraph, i + n_fuse, ops + n_fuse, 2 )) {
@@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37783778 n_fuse++;
37793779
37803780 if (n_fuse > 1 ) {
3781- ggml_tensor fused_add_node ;
3782- memcpy (&fused_add_node , node, sizeof (ggml_tensor));
3781+ ggml_tensor fused_node ;
3782+ memcpy (&fused_node , node, sizeof (ggml_tensor));
37833783 for (int j = 0 ; j < n_fuse - 1 ; ++j) {
3784- fused_add_node.src [j + 2 ] = cgraph->nodes [i + j + 1 ]->src [1 ];
3784+ fused_node.src [j + 2 ] = cgraph->nodes [i + j + 1 ]->src [1 ];
3785+ }
3786+ fused_node.data = cgraph->nodes [i + n_fuse - 1 ]->data ;
3787+ if (node->op == GGML_OP_ADD) {
3788+ ggml_cuda_op_fused_add (*cuda_ctx, &fused_node, n_fuse);
3789+ } else {
3790+ ggml_cuda_op_fused_mul (*cuda_ctx, &fused_node, n_fuse);
37853791 }
3786- fused_add_node.data = cgraph->nodes [i + n_fuse - 1 ]->data ;
3787- ggml_cuda_op_fused_add (*cuda_ctx, &fused_add_node, n_fuse);
37883792 i += n_fuse - 1 ;
37893793
37903794 continue ;
0 commit comments