
Commit e227551

warm-up for faster speed
1 parent ae30f12 commit e227551
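
This commit adds a llamaCPP::warmupModel() method and calls it as soon as loadModel() succeeds: it pushes a throwaway prompt ("hello") through the freshly loaded model for a single predicted token, so the one-time cold-start cost is paid at load time rather than on the first user request. For context only, a minimal standalone sketch of the same warm-up idea against the raw llama.cpp C API of this era is shown below; the helper name warmup_once and its setup are illustrative assumptions and are not part of the commit, which instead drives the bundled llama_server_context (see the diff).

// Illustrative sketch only (not from this commit): warm up a freshly loaded
// model by decoding a tiny prompt once and discarding the result.
// Assumes the llama.cpp revision bundled in this repo (llama_decode /
// llama_batch_get_one era) and a valid llama_context *ctx.
#include "llama.h" // header path may differ in this repo's build
#include <vector>

static void warmup_once(llama_context *ctx, std::vector<llama_token> tokens) {
  if (tokens.empty()) {
    return; // nothing to decode
  }
  // A single decode of the prompt touches the weights and builds the compute
  // graph; any failure here is non-fatal, real requests simply start cold.
  llama_decode(ctx, llama_batch_get_one(tokens.data(), (int)tokens.size(),
                                        /*pos_0=*/0, /*seq_id=*/0));
}

In the actual change, the equivalent work happens inside llamaCPP::warmupModel(), which reuses the server context's own machinery (loadPrompt(), beginCompletion(), doCompletion()), so the warm-up exercises the same code path as a real completion request.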

File tree

2 files changed: +69 -19 lines changed


controllers/llamaCPP.cc

Lines changed: 45 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include <drogon/HttpTypes.h>
 #include <regex>
 #include <thread>
+#include <trantor/utils/Logger.h>

 using namespace inferences;

@@ -39,6 +40,49 @@ std::string create_return_json(const std::string &id, const std::string &model,
   return Json::writeString(writer, root);
 }

+void llamaCPP::warmupModel() {
+  auto lock = llama.lock();
+  llama.rewind();
+  llama_reset_timings(llama.ctx);
+
+  llama.prompt = "hello";
+  llama.params.n_predict = 1;
+  llama.loadPrompt();
+  llama.beginCompletion();
+  size_t stop_pos = std::string::npos;
+
+  while (llama.has_next_token) {
+    const completion_token_output token_with_probs = llama.doCompletion();
+    const std::string token_text =
+        token_with_probs.tok == -1
+            ? ""
+            : llama_token_to_piece(llama.ctx, token_with_probs.tok);
+
+    stop_pos = llama.findStoppingStrings(llama.generated_text,
+                                         token_text.size(), STOP_FULL);
+  }
+
+  if (stop_pos == std::string::npos) {
+    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+  }
+  if (stop_pos != std::string::npos) {
+    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                               llama.generated_text.end());
+  }
+  auto probs = llama.generated_token_probs;
+  if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
+    const std::vector<llama_token> stop_word_toks =
+        llama_tokenize(llama.ctx, llama.stopping_word, false);
+    probs = std::vector<completion_token_output>(
+        llama.generated_token_probs.begin(),
+        llama.generated_token_probs.end() - stop_word_toks.size());
+  }
+
+  LOG_INFO << llama.generated_text;
+  LOG_INFO << "Finish the warmup";
+  return;
+}
+
 void llamaCPP::chatCompletion(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -297,5 +341,6 @@ void llamaCPP::loadModel(
   jsonResp["message"] = "Model loaded successfully";
   model_loaded = true;
   auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+  warmupModel();
   callback(resp);
 }

controllers/llamaCPP.h

Lines changed: 24 additions & 19 deletions
@@ -525,12 +525,12 @@ struct llama_server_context {
     if (llama_decode(ctx,
                      llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) {
       LOG_ERROR_LLAMA("failed to eval",
-                      {
-                          {"n_eval", n_eval},
-                          {"n_past", n_past},
-                          {"embd",
-                           tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
-                      });
+                      {
+                          {"n_eval", n_eval},
+                          {"n_past", n_past},
+                          {"embd", tokens_to_str(ctx, embd.cbegin() + n_past,
+                                                 embd.cend())},
+                      });
       has_next_token = false;
       return result;
     }
@@ -677,9 +677,9 @@ struct llama_server_context {
   static const int n_embd = llama_n_embd(model);
   if (!params.embedding) {
     LOG_WARNING_LLAMA("embedding disabled",
-                      {
-                          {"params.embedding", params.embedding},
-                      });
+                      {
+                          {"params.embedding", params.embedding},
+                      });
     return std::vector<float>(n_embd, 0.0f);
   }
   const float *data = llama_get_embeddings(ctx);
@@ -891,17 +891,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
       }
     }
 #else
-      LOG_WARNING_LLAMA("llama.cpp was compiled without cuBLAS. It is not possible "
-                        "to set a tensor split.\n",
-                        {});
+      LOG_WARNING_LLAMA(
+          "llama.cpp was compiled without cuBLAS. It is not possible "
+          "to set a tensor split.\n",
+          {});
 #endif // GGML_USE_CUBLAS
     } else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
 #ifdef GGML_USE_CUBLAS
       params.mul_mat_q = false;
 #else
-      LOG_WARNING_LLAMA("warning: llama.cpp was compiled without cuBLAS. Disabling "
-                        "mul_mat_q kernels has no effect.\n",
-                        {});
+      LOG_WARNING_LLAMA(
+          "warning: llama.cpp was compiled without cuBLAS. Disabling "
+          "mul_mat_q kernels has no effect.\n",
+          {});
 #endif // GGML_USE_CUBLAS
     } else if (arg == "--main-gpu" || arg == "-mg") {
       if (++i >= argc) {
@@ -911,9 +913,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 #ifdef GGML_USE_CUBLAS
       params.main_gpu = std::stoi(argv[i]);
 #else
-      LOG_WARNING_LLAMA("llama.cpp was compiled without cuBLAS. It is not possible "
-                        "to set a main GPU.",
-                        {});
+      LOG_WARNING_LLAMA(
+          "llama.cpp was compiled without cuBLAS. It is not possible "
+          "to set a main GPU.",
+          {});
 #endif
     } else if (arg == "--lora") {
       if (++i >= argc) {
@@ -1260,7 +1263,8 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
 public:
   llamaCPP() {
     // Some default values for now below
-    log_disable(); //Disable the log to file feature, reduce bloat for target system ()
+    log_disable(); // Disable the log to file feature, reduce bloat for target
+                   // system ()
   }
   METHOD_LIST_BEGIN
   // list path definitions here;
@@ -1275,6 +1279,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
                      std::function<void(const HttpResponsePtr &)> &&callback);
   void loadModel(const HttpRequestPtr &req,
                  std::function<void(const HttpResponsePtr &)> &&callback);
+  void warmupModel();

 private:
   llama_server_context llama;
