This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit a321e82

refactor handler and impl for separation of concerns
1 parent e5a973f commit a321e82
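
The commit splits each HTTP endpoint into a thin Drogon handler that unpacks the request and a framework-agnostic *Impl method that takes only the parsed JSON body and the response callback. A minimal sketch of that pattern follows, using stand-in Request/Response types and jsoncpp instead of the real Drogon classes; the names below are illustrative, not part of this commit:

// Sketch only: stand-in types replace Drogon's HttpRequestPtr/HttpResponsePtr.
#include <functional>
#include <iostream>
#include <memory>
#include <json/json.h>

struct Request  { std::shared_ptr<Json::Value> json; }; // ~ HttpRequestPtr
struct Response { std::string body; };                   // ~ HttpResponsePtr
using Callback = std::function<void(const Response &)>;

// Impl: does the work; needs no HTTP request, so it can be called directly.
void chatCompletionImpl(std::shared_ptr<Json::Value> jsonBody, Callback &callback) {
  Json::Value reply;
  reply["echo"] = (*jsonBody)["messages"];
  callback(Response{Json::writeString(Json::StreamWriterBuilder(), reply)});
}

// Handler: owns everything transport-specific, then delegates.
void chatCompletion(const Request &req, Callback &&callback) {
  const auto &jsonBody = req.json; // mirrors req->getJsonObject()
  chatCompletionImpl(jsonBody, callback);
}

int main() {
  Request req{std::make_shared<Json::Value>()};
  (*req.json)["messages"] = "hello";
  chatCompletion(req, [](const Response &r) { std::cout << r.body << "\n"; });
}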

File tree

2 files changed (+58, -41 lines)


controllers/llamaCPP.cc

Lines changed: 51 additions & 38 deletions
@@ -2,6 +2,7 @@
 #include "llama.h"
 #include "log.h"
 #include "utils/nitro_utils.h"
+#include <algorithm>
 
 using namespace inferences;
 using json = nlohmann::json;
@@ -23,7 +24,6 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
 
 // Function to check if the model is loaded
 void llamaCPP::checkModelLoaded(
-    const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &callback) {
   if (!llama.model_loaded_external) {
     Json::Value jsonResp;
@@ -151,10 +151,17 @@ void llamaCPP::chatCompletion(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
 
+  const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
-  checkModelLoaded(req, callback);
+  checkModelLoaded(callback);
+
+  chatCompletionImpl(jsonBody, callback);
+}
+
+void llamaCPP::chatCompletionImpl(
+    std::shared_ptr<Json::Value> jsonBody,
+    std::function<void(const HttpResponsePtr &)> &callback) {
 
-  const auto &jsonBody = req->getJsonObject();
   std::string formatted_output = pre_prompt;
 
   json data;
@@ -402,17 +409,23 @@ void llamaCPP::chatCompletion(
     }
   }
 }
+
 void llamaCPP::embedding(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
-  checkModelLoaded(req, callback);
+  checkModelLoaded(callback);
+  const auto &jsonBody = req->getJsonObject();
 
-  auto state = create_inference_state(this);
+  embeddingImpl(jsonBody, callback);
+  return;
+}
 
-  const auto &jsonBody = req->getJsonObject();
+void llamaCPP::embeddingImpl(
+    std::shared_ptr<Json::Value> jsonBody,
+    std::function<void(const HttpResponsePtr &)> &callback) {
 
   Json::Value responseData(Json::arrayValue);
-
+  auto state = create_inference_state(this);
   if (jsonBody->isMember("input")) {
     // If single queue is busy, we will wait if not we will just go ahead and
     // process and make it busy, and yet i'm aware not DRY, i have the same
@@ -464,7 +477,6 @@ void llamaCPP::embedding(
   resp->setBody(Json::writeString(Json::StreamWriterBuilder(), root));
   resp->setContentTypeString("application/json");
   callback(resp);
-  return;
 }
 
 void llamaCPP::unloadModel(
@@ -502,30 +514,30 @@ void llamaCPP::modelStatus(
   return;
 }
 
-bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
+bool llamaCPP::loadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
 
   gpt_params params;
-
   // By default will setting based on number of handlers
   if (jsonBody) {
-    if (!jsonBody["mmproj"].isNull()) {
+    if (!jsonBody->operator[]("mmproj").isNull()) {
       LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
-      params.mmproj = jsonBody["mmproj"].asString();
+      params.mmproj = jsonBody->operator[]("mmproj").asString();
     }
-    if (!jsonBody["grp_attn_n"].isNull()) {
+    if (!jsonBody->operator[]("grp_attn_n").isNull()) {
 
-      params.grp_attn_n = jsonBody["grp_attn_n"].asInt();
+      params.grp_attn_n = jsonBody->operator[]("grp_attn_n").asInt();
     }
-    if (!jsonBody["grp_attn_w"].isNull()) {
+    if (!jsonBody->operator[]("grp_attn_w").isNull()) {
 
-      params.grp_attn_w = jsonBody["grp_attn_w"].asInt();
+      params.grp_attn_w = jsonBody->operator[]("grp_attn_w").asInt();
     }
-    if (!jsonBody["mlock"].isNull()) {
-      params.use_mlock = jsonBody["mlock"].asBool();
+    if (!jsonBody->operator[]("mlock").isNull()) {
+      params.use_mlock = jsonBody->operator[]("mlock").asBool();
     }
 
-    if (!jsonBody["grammar_file"].isNull()) {
-      std::string grammar_file = jsonBody["grammar_file"].asString();
+    if (!jsonBody->operator[]("grammar_file").isNull()) {
+      std::string grammar_file =
+          jsonBody->operator[]("grammar_file").asString();
       std::ifstream file(grammar_file);
       if (!file) {
         LOG_ERROR << "Grammar file not found";
@@ -536,30 +548,31 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
       }
     };
 
-    params.model = jsonBody["llama_model_path"].asString();
-    params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
-    params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
-    params.embedding = jsonBody.get("embedding", true).asBool();
+    params.model = jsonBody->operator[]("llama_model_path").asString();
+    params.n_gpu_layers = jsonBody->get("ngl", 100).asInt();
+    params.n_ctx = jsonBody->get("ctx_len", 2048).asInt();
+    params.embedding = jsonBody->get("embedding", true).asBool();
     // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
-    params.n_batch = jsonBody.get("n_batch", 512).asInt();
-    params.n_parallel = jsonBody.get("n_parallel", 1).asInt();
+    params.n_batch = jsonBody->get("n_batch", 512).asInt();
+    params.n_parallel = jsonBody->get("n_parallel", 1).asInt();
     params.n_threads =
-        jsonBody.get("cpu_threads", std::thread::hardware_concurrency())
+        jsonBody->get("cpu_threads", std::thread::hardware_concurrency())
             .asInt();
-    params.cont_batching = jsonBody.get("cont_batching", false).asBool();
+    params.cont_batching = jsonBody->get("cont_batching", false).asBool();
     this->clean_cache_threshold =
-        jsonBody.get("clean_cache_threshold", 5).asInt();
-    this->caching_enabled = jsonBody.get("caching_enabled", false).asBool();
-    this->user_prompt = jsonBody.get("user_prompt", "USER: ").asString();
-    this->ai_prompt = jsonBody.get("ai_prompt", "ASSISTANT: ").asString();
+        jsonBody->get("clean_cache_threshold", 5).asInt();
+    this->caching_enabled = jsonBody->get("caching_enabled", false).asBool();
+    this->user_prompt = jsonBody->get("user_prompt", "USER: ").asString();
+    this->ai_prompt = jsonBody->get("ai_prompt", "ASSISTANT: ").asString();
     this->system_prompt =
-        jsonBody.get("system_prompt", "ASSISTANT's RULE: ").asString();
-    this->pre_prompt = jsonBody.get("pre_prompt", "").asString();
-    this->repeat_last_n = jsonBody.get("repeat_last_n", 32).asInt();
+        jsonBody->get("system_prompt", "ASSISTANT's RULE: ").asString();
+    this->pre_prompt = jsonBody->get("pre_prompt", "").asString();
+    this->repeat_last_n = jsonBody->get("repeat_last_n", 32).asInt();
 
-    if (!jsonBody["llama_log_folder"].isNull()) {
+    if (!jsonBody->operator[]("llama_log_folder").isNull()) {
       log_enable();
-      std::string llama_log_folder = jsonBody["llama_log_folder"].asString();
+      std::string llama_log_folder =
+          jsonBody->operator[]("llama_log_folder").asString();
       log_set_target(llama_log_folder + "llama.log");
     } // Set folder for llama log
   }
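
A note on the access pattern above: with jsonBody now a std::shared_ptr<Json::Value>, jsonBody->operator[]("key") is simply the explicit spelling of (*jsonBody)["key"], and jsonBody->get("key", default) keeps the previous default-value behavior. A self-contained check (jsoncpp only, values are placeholders):

#include <cassert>
#include <memory>
#include <json/json.h>

int main() {
  auto jsonBody = std::make_shared<Json::Value>();
  (*jsonBody)["ngl"] = 50;
  // operator[] through the pointer and through the dereference are the same call.
  assert(jsonBody->operator[]("ngl").asInt() == (*jsonBody)["ngl"].asInt());
  // get() still falls back to the default when the key is missing.
  assert(jsonBody->get("ctx_len", 2048).asInt() == 2048);
  return 0;
}
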
@@ -612,7 +625,7 @@ void llamaCPP::loadModel(
   }
 
   const auto &jsonBody = req->getJsonObject();
-  if (!loadModelImpl(*jsonBody)) {
+  if (!loadModelImpl(jsonBody)) {
     // Error occurred during model loading
     Json::Value jsonResp;
     jsonResp["message"] = "Failed to load model";

controllers/llamaCPP.h

Lines changed: 7 additions & 3 deletions
@@ -2569,11 +2569,15 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
   // condition n_parallel is 1
   std::string grammar_file_content;
 
-  bool loadModelImpl(const Json::Value &jsonBody);
+  bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
+  void
+  chatCompletionImpl(std::shared_ptr<Json::Value> jsonBody,
+                     std::function<void(const HttpResponsePtr &)> &callback);
+  void embeddingImpl(std::shared_ptr<Json::Value> jsonBody,
+                     std::function<void(const HttpResponsePtr &)> &callback);
+  void checkModelLoaded(std::function<void(const HttpResponsePtr &)> &callback);
   void warmupModel();
   void backgroundTask();
   void stopBackgroundTask();
-  void checkModelLoaded(const HttpRequestPtr &req,
-                        std::function<void(const HttpResponsePtr &)> &callback);
 };
 }; // namespace inferences
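
The relocated checkModelLoaded declaration drops the HttpRequestPtr parameter, so the guard only needs the response callback that every handler already holds. A minimal sketch of that shape, with stand-in types and a placeholder error message that is not taken from this commit:

#include <functional>
#include <iostream>
#include <string>

struct Response { std::string body; };
using Callback = std::function<void(const Response &)>;

bool model_loaded_external = false; // stands in for llama.model_loaded_external

// Same shape as the new declaration: only the callback is required.
void checkModelLoaded(Callback &callback) {
  if (!model_loaded_external) {
    callback(Response{"{\"message\":\"model not loaded\"}"}); // placeholder body
  }
}

int main() {
  Callback cb = [](const Response &r) { std::cout << r.body << "\n"; };
  checkModelLoaded(cb);
}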

0 commit comments