 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
+#include <algorithm>
 
 using namespace inferences;
 using json = nlohmann::json;
@@ -23,7 +24,6 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
 
 // Function to check if the model is loaded
 void llamaCPP::checkModelLoaded(
-    const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &callback) {
   if (!llama.model_loaded_external) {
     Json::Value jsonResp;
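Note on the signature change: checkModelLoaded never read anything from the request; it only inspects llama.model_loaded_external and, when the model is not loaded, builds an error response for the callback. Dropping the unused HttpRequestPtr lets callers that have no HTTP request in hand (the new *Impl entry points below) reuse the check. A minimal sketch of calling it outside a handler, assuming the declaration above; because the parameter is a non-const reference, the callback must be a named lvalue:

    // Hypothetical internal caller of checkModelLoaded.
    std::function<void(const HttpResponsePtr &)> cb =
        [](const HttpResponsePtr &resp) {
          // Handle the "model not loaded" reply here.
        };
    checkModelLoaded(cb);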
@@ -151,10 +151,17 @@ void llamaCPP::chatCompletion(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
 
+  const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
-  checkModelLoaded(req, callback);
+  checkModelLoaded(callback);
+
+  chatCompletionImpl(jsonBody, callback);
+}
+
+void llamaCPP::chatCompletionImpl(
+    std::shared_ptr<Json::Value> jsonBody,
+    std::function<void(const HttpResponsePtr &)> &callback) {
 
-  const auto &jsonBody = req->getJsonObject();
   std::string formatted_output = pre_prompt;
 
   json data;
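The HTTP handler now only pulls the JSON body off the request and delegates to chatCompletionImpl, so the same completion path can be driven without a Drogon request object. A rough sketch of such a direct call from another member function of llamaCPP, assuming the signatures introduced above; the body keys and values are placeholders, not taken from this diff:

    // Hypothetical direct invocation of the new Impl method.
    auto body = std::make_shared<Json::Value>();
    (*body)["messages"] = Json::Value(Json::arrayValue); // fill with chat turns
    std::function<void(const HttpResponsePtr &)> cb =
        [](const HttpResponsePtr &resp) {
          // Inspect or forward the generated response here.
        };
    chatCompletionImpl(body, cb);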
@@ -402,17 +409,23 @@ void llamaCPP::chatCompletion(
     }
   }
 }
+
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-  checkModelLoaded(req, callback);
+  checkModelLoaded(callback);
+  const auto &jsonBody = req->getJsonObject();
 
-  auto state = create_inference_state(this);
+  embeddingImpl(jsonBody, callback);
+  return;
+}
 
-  const auto &jsonBody = req->getJsonObject();
+void llamaCPP::embeddingImpl(
+    std::shared_ptr<Json::Value> jsonBody,
+    std::function<void(const HttpResponsePtr &)> &callback) {
 
   Json::Value responseData(Json::arrayValue);
-
+  auto state = create_inference_state(this);
   if (jsonBody->isMember("input")) {
     // If single queue is busy, we will wait if not we will just go ahead and
     // process and make it busy, and yet i'm aware not DRY, i have the same
@@ -464,7 +477,6 @@ void llamaCPP::embedding(
   resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
   resp->setContentTypeString("application/json");
   callback(resp);
-  return;
 }
 
 void llamaCPP::unloadModel(
@@ -502,30 +514,30 @@ void llamaCPP::modelStatus(
   return;
 }
 
-bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
+bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
 
   gpt_params params;
-
   // By default will setting based on number of handlers
   if (jsonBody) {
-    if (!jsonBody["mmproj"].isNull()) {
+    if (!jsonBody->operator[]("mmproj").isNull()) {
       LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
-      params.mmproj = jsonBody["mmproj"].asString();
+      params.mmproj = jsonBody->operator[]("mmproj").asString();
     }
-    if (!jsonBody["grp_attn_n"].isNull()) {
+    if (!jsonBody->operator[]("grp_attn_n").isNull()) {
 
-      params.grp_attn_n = jsonBody["grp_attn_n"].asInt();
+      params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
     }
-    if (!jsonBody["grp_attn_w"].isNull()) {
+    if (!jsonBody->operator[]("grp_attn_w").isNull()) {
 
-      params.grp_attn_w = jsonBody["grp_attn_w"].asInt();
+      params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
     }
-    if (!jsonBody["mlock"].isNull()) {
-      params.use_mlock = jsonBody["mlock"].asBool();
+    if (!jsonBody->operator[]("mlock").isNull()) {
+      params.use_mlock = jsonBody->operator[]("mlock").asBool();
     }
 
-    if (!jsonBody["grammar_file"].isNull()) {
-      std::string grammar_file = jsonBody["grammar_file"].asString();
+    if (!jsonBody->operator[]("grammar_file").isNull()) {
+      std::string grammar_file =
+          jsonBody->operator[]("grammar_file").asString();
       std::ifstream file(grammar_file);
       if (!file) {
         LOG_ERROR << "Grammar file not found";
@@ -536,30 +548,31 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
       }
     };
 
-    params.model = jsonBody["llama_model_path"].asString();
-    params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
-    params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
-    params.embedding = jsonBody.get("embedding", true).asBool();
+    params.model = jsonBody->operator[]("llama_model_path").asString();
+    params.n_gpu_layers = jsonBody->get("ngl", 100).asInt();
+    params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
+    params.embedding = jsonBody->get("embedding", true).asBool();
     // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
-    params.n_batch = jsonBody.get("n_batch", 512).asInt();
-    params.n_parallel = jsonBody.get("n_parallel", 1).asInt();
+    params.n_batch = jsonBody->get("n_batch", 512).asInt();
+    params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
     params.n_threads =
-        jsonBody.get("cpu_threads", std::thread::hardware_concurrency())
+        jsonBody->get("cpu_threads", std::thread::hardware_concurrency())
             .asInt();
-    params.cont_batching = jsonBody.get("cont_batching", false).asBool();
+    params.cont_batching = jsonBody->get("cont_batching", false).asBool();
     this->clean_cache_threshold =
-        jsonBody.get("clean_cache_threshold", 5).asInt();
-    this->caching_enabled = jsonBody.get("caching_enabled", false).asBool();
-    this->user_prompt = jsonBody.get("user_prompt", "USER: ").asString();
-    this->ai_prompt = jsonBody.get("ai_prompt", "ASSISTANT: ").asString();
+        jsonBody->get("clean_cache_threshold", 5).asInt();
+    this->caching_enabled = jsonBody->get("caching_enabled", false).asBool();
+    this->user_prompt = jsonBody->get("user_prompt", "USER: ").asString();
+    this->ai_prompt = jsonBody->get("ai_prompt", "ASSISTANT: ").asString();
     this->system_prompt =
-        jsonBody.get("system_prompt", "ASSISTANT's RULE: ").asString();
-    this->pre_prompt = jsonBody.get("pre_prompt", "").asString();
-    this->repeat_last_n = jsonBody.get("repeat_last_n", 32).asInt();
+        jsonBody->get("system_prompt", "ASSISTANT's RULE: ").asString();
+    this->pre_prompt = jsonBody->get("pre_prompt", "").asString();
+    this->repeat_last_n = jsonBody->get("repeat_last_n", 32).asInt();
 
-    if (!jsonBody["llama_log_folder"].isNull()) {
+    if (!jsonBody->operator[]("llama_log_folder").isNull()) {
       log_enable();
-      std::string llama_log_folder = jsonBody["llama_log_folder"].asString();
+      std::string llama_log_folder =
+          jsonBody->operator[]("llama_log_folder").asString();
       log_set_target(llama_log_folder + "llama.log");
     } // Set folder for llama log
   }
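With the body now passed as std::shared_ptr<Json::Value>, bracket lookups are spelled jsonBody->operator[]("key") and defaulted reads become jsonBody->get("key", default). Dereferencing the pointer once is equivalent and reads a little more naturally; for example, the mmproj block above could also be written as the sketch below (same behavior, purely a spelling difference). One subtlety: operator[] on a non-const Json::Value inserts a null member when the key is absent, while get() does not modify the value, so the two are not interchangeable if the body is inspected again later.

    // Equivalent spelling of the lookups above, dereferencing the shared_ptr once.
    Json::Value &body = *jsonBody;
    if (!body["mmproj"].isNull()) {
      LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
      params.mmproj = body["mmproj"].asString();
    }
    params.n_gpu_layers = body.get("ngl", 100).asInt();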
@@ -612,7 +625,7 @@ void llamaCPP::loadModel(
   }
 
   const auto &jsonBody = req->getJsonObject();
-  if (!loadModelImpl(*jsonBody)) {
+  if (!loadModelImpl(jsonBody)) {
     // Error occurred during model loading
     Json::Value jsonResp;
     jsonResp["message"] = "Failed to load model";
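Since loadModelImpl now accepts the shared pointer directly, a caller can also build the body programmatically instead of going through the HTTP endpoint. A hedged sketch using only keys that loadModelImpl reads above; the model path is a placeholder:

    // Hypothetical programmatic load from inside the controller.
    auto body = std::make_shared<Json::Value>();
    (*body)["llama_model_path"] = "/path/to/model.gguf"; // placeholder path
    (*body)["ctx_len"] = 2048;
    (*body)["ngl"] = 100;
    if (!loadModelImpl(body)) {
      LOG_ERROR << "Failed to load model";
    }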