@@ -359,47 +359,32 @@ void llamaCPP::modelStatus(
   return;
 }
 
-void llamaCPP::loadModel(
-    const HttpRequestPtr &req,
-    std::function<void(const HttpResponsePtr &)> &&callback) {
-
-  if (model_loaded) {
-    LOG_INFO << "model loaded";
-    Json::Value jsonResp;
-    jsonResp["message"] = "Model already loaded";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k409Conflict);
-    callback(resp);
-    return;
-  }
-
-  const auto &jsonBody = req->getJsonObject();
+bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
 
   gpt_params params;
 
   // By default will setting based on number of handlers
   int drogon_thread = drogon::app().getThreadNum();
   LOG_INFO << "Drogon thread is:" << drogon_thread;
   if (jsonBody) {
-    params.model = (*jsonBody)["llama_model_path"].asString();
-    params.n_gpu_layers = (*jsonBody).get("ngl", 100).asInt();
-    params.n_ctx = (*jsonBody).get("ctx_len", 2048).asInt();
-    params.embedding = (*jsonBody).get("embedding", true).asBool();
+    params.model = jsonBody["llama_model_path"].asString();
+    params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
+    params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
+    params.embedding = jsonBody.get("embedding", true).asBool();
     // Check if n_parallel exists in jsonBody, if not, set to drogon_thread
-    params.n_batch = (*jsonBody).get("n_batch", 512).asInt();
-    params.n_parallel = (*jsonBody).get("n_parallel", drogon_thread).asInt();
+    params.n_batch = jsonBody.get("n_batch", 512).asInt();
+    params.n_parallel = jsonBody.get("n_parallel", drogon_thread).asInt();
     params.n_threads =
-        (*jsonBody)
-            .get("cpu_threads", std::thread::hardware_concurrency())
+        jsonBody.get("cpu_threads", std::thread::hardware_concurrency())
             .asInt();
-    params.cont_batching = (*jsonBody).get("cont_batching", false).asBool();
+    params.cont_batching = jsonBody.get("cont_batching", false).asBool();
 
-    this->user_prompt = (*jsonBody).get("user_prompt", "USER: ").asString();
-    this->ai_prompt = (*jsonBody).get("ai_prompt", "ASSISTANT: ").asString();
+    this->user_prompt = jsonBody.get("user_prompt", "USER: ").asString();
+    this->ai_prompt = jsonBody.get("ai_prompt", "ASSISTANT: ").asString();
     this->system_prompt =
-        (*jsonBody).get("system_prompt", "ASSISTANT's RULE: ").asString();
-    this->pre_prompt = (*jsonBody).get("pre_prompt", "").asString();
-    this->repeat_last_n = (*jsonBody).get("repeat_last_n", 32).asInt();
+        jsonBody.get("system_prompt", "ASSISTANT's RULE: ").asString();
+    this->pre_prompt = jsonBody.get("pre_prompt", "").asString();
+    this->repeat_last_n = jsonBody.get("repeat_last_n", 32).asInt();
   }
 #ifdef GGML_USE_CUBLAS
   LOG_INFO << "Setting up GGML CUBLAS PARAMS";
@@ -422,25 +407,46 @@ void llamaCPP::loadModel(
 
   // load the model
   if (!llama.load_model(params)) {
-    LOG_ERROR << "Error loading the model will exit the program";
-    Json::Value jsonResp;
-    jsonResp["message"] = "Failed to load model";
-    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-    resp->setStatusCode(drogon::k500InternalServerError);
-    callback(resp);
+    LOG_ERROR << "Error loading the model";
+    return false; // Indicate failure
   }
   llama.initialize();
-
-  Json::Value jsonResp;
-  jsonResp["message"] = "Model loaded successfully";
   model_loaded = true;
-  auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
-
   LOG_INFO << "Started background task here!";
   backgroundThread = std::thread(&llamaCPP::backgroundTask, this);
   warmupModel();
+  return true;
+}
 
-  callback(resp);
+void llamaCPP::loadModel(
+    const HttpRequestPtr &req,
+    std::function<void(const HttpResponsePtr &)> &&callback) {
+
+  if (model_loaded) {
+    LOG_INFO << "model loaded";
+    Json::Value jsonResp;
+    jsonResp["message"] = "Model already loaded";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k409Conflict);
+    callback(resp);
+    return;
+  }
+
+  const auto &jsonBody = req->getJsonObject();
+  if (!loadModelImpl(*jsonBody)) {
+    // Error occurred during model loading
+    Json::Value jsonResp;
+    jsonResp["message"] = "Failed to load model";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    resp->setStatusCode(drogon::k500InternalServerError);
+    callback(resp);
+  } else {
+    // Model loaded successfully
+    Json::Value jsonResp;
+    jsonResp["message"] = "Model loaded successfully";
+    auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+    callback(resp);
+  }
 }
 
 void llamaCPP::backgroundTask() {
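
The net effect of this diff is that JSON parsing and model loading now live in loadModelImpl(const Json::Value &), which returns a plain bool, while loadModel keeps only the HTTP concerns (the 409 guard and the 500/200 responses). The sketch below is not part of the commit; it illustrates how the decoupled entry point could be driven without an HTTP request, for example to preload a model at startup. The function preloadFromConfig and the way the controller instance is obtained are hypothetical, and it assumes loadModelImpl is reachable from the calling code; only the JSON keys and their defaults are taken from the hunks above.

// Hypothetical sketch (not in this commit): build the same Json::Value the
// HTTP handler would receive and hand it straight to loadModelImpl(),
// skipping the Drogon request/response layer entirely.
// Assumes llamaCPP's header is included and loadModelImpl() is callable here.
#include <json/json.h>
#include <string>
#include <thread>

bool preloadFromConfig(llamaCPP &controller, const std::string &model_path) {
  Json::Value body;
  body["llama_model_path"] = model_path; // the only key without a default
  body["ngl"] = 100;                     // defaults below mirror the diff
  body["ctx_len"] = 2048;
  body["embedding"] = true;
  body["n_batch"] = 512;
  body["cpu_threads"] = static_cast<int>(std::thread::hardware_concurrency());
  body["cont_batching"] = false;
  // n_parallel is omitted so it falls back to drogon::app().getThreadNum().
  return controller.loadModelImpl(body); // true on success, false on failure
}

Returning a bool instead of writing the HTTP response is what makes this reuse possible: the transport-specific status handling stays in loadModel, while the loading path itself can be unit-tested or invoked from configuration code.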