Skip to content

Commit f47edb8

Browse files
authored
ggml-cuda: check for srcs outside the cgraph (ggml-org#18583)
* ggml-cuda: check for srcs outside the cgraph
* review: use leafs instead
1 parent da143b9 commit f47edb8

2 files changed

Lines changed: 15 additions & 2 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,7 @@ struct ggml_cuda_graph {
10651065
int number_consecutive_updates = 0;
10661066
bool cuda_graphs_enabled = false;
10671067
std::vector<ggml_graph_node_properties> ggml_graph_properties;
1068+
std::vector<ggml_graph_node_properties> extraneous_srcs_properties;
10681069
#endif
10691070
};
10701071

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,15 +2973,16 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
29732973
}
29742974

29752975
// Check if the graph size has changed
2976-
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2976+
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
29772977
cuda_graph_update_required = true;
2978-
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2978+
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes + cgraph->n_leafs);
29792979
}
29802980

29812981
// Loop over nodes in GGML graph to determine if CUDA graph update is required
29822982
// and store properties to allow this comparison for the next token
29832983
for (int i = 0; i < cgraph->n_nodes; i++) {
29842984
bool has_matching_properties = true;
2985+
29852986
if (!cuda_graph_update_required) {
29862987
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
29872988
}
@@ -2991,6 +2992,17 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
29912992
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
29922993
}
29932994

2995+
for (int i = 0; i < cgraph->n_leafs; i++) {
2996+
bool has_matching_properties = true;
2997+
if (!cuda_graph_update_required) {
2998+
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
2999+
}
3000+
if (!has_matching_properties) {
3001+
cuda_graph_update_required = true;
3002+
}
3003+
set_ggml_graph_node_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
3004+
}
3005+
29943006
return cuda_graph_update_required;
29953007
}
29963008

0 commit comments

Comments
 (0)