
Commit 8c63530

fix: race condition between stream and non-stream inference (#507)

Parent: 70b2a92

2 files changed, +35 −32 lines


controllers/llamaCPP.cc

Lines changed: 32 additions & 29 deletions
@@ -49,7 +49,7 @@ std::shared_ptr<inferenceState> create_inference_state(llamaCPP* instance) {
  * @param callback the function to return message to user
  */
 bool llamaCPP::CheckModelLoaded(
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    const std::function<void(const HttpResponsePtr&)>& callback) {
   if (!llama.model_loaded_external) {
     LOG_ERROR << "Model has not been loaded";
     Json::Value jsonResp;
@@ -180,13 +180,13 @@ void llamaCPP::ChatCompletion(
   if (CheckModelLoaded(callback)) {
     // Model is loaded
     // Do Inference
-    InferenceImpl(std::move(completion), callback);
+    InferenceImpl(std::move(completion), std::move(callback));
   }
 }
 
 void llamaCPP::InferenceImpl(
     inferences::ChatCompletionRequest&& completion,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   std::string formatted_output = pre_prompt;
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for inference request";
@@ -405,14 +405,14 @@ void llamaCPP::InferenceImpl(
     };
     // Queued task
     state->instance->queue->runTaskInQueue(
-        [callback, state, data, chunked_content_provider, request_id]() {
+        [cb = std::move(callback), state, data, chunked_content_provider, request_id]() {
           state->task_id =
               state->instance->llama.request_completion(data, false, false, -1);
 
           // Start streaming response
           auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
                                                        "chat_completions.txt");
-          callback(resp);
+          cb(resp);
 
           int retries = 0;
 
@@ -435,28 +435,31 @@ void llamaCPP::InferenceImpl(
           LOG_INFO_REQUEST(request_id) << "Inference completed";
         });
   } else {
-    Json::Value respData;
-    int task_id = llama.request_completion(data, false, false, -1);
-    LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
-    if (!json_value(data, "stream", false)) {
-      std::string completion_text;
-      task_result result = llama.next_result(task_id);
-      if (!result.error && result.stop) {
-        int prompt_tokens = result.result_json["tokens_evaluated"];
-        int predicted_tokens = result.result_json["tokens_predicted"];
-        std::string to_send = result.result_json["content"];
-        nitro_utils::ltrim(to_send);
-        respData = create_full_return_json(
-            nitro_utils::generate_random_string(20), "_", to_send, "_",
-            prompt_tokens, predicted_tokens);
-      } else {
-        respData["message"] = "Internal error during inference";
-        LOG_ERROR_REQUEST(request_id) << "Error during inference";
-      }
-      auto resp = nitro_utils::nitroHttpJsonResponse(respData);
-      callback(resp);
-      LOG_INFO_REQUEST(request_id) << "Inference completed";
-    }
+    queue->runTaskInQueue(
+        [this, request_id, cb = std::move(callback), d = std::move(data)]() {
+          Json::Value respData;
+          int task_id = llama.request_completion(d, false, false, -1);
+          LOG_INFO_REQUEST(request_id) << "Non stream, waiting for respone";
+          if (!json_value(d, "stream", false)) {
+            std::string completion_text;
+            task_result result = llama.next_result(task_id);
+            if (!result.error && result.stop) {
+              int prompt_tokens = result.result_json["tokens_evaluated"];
+              int predicted_tokens = result.result_json["tokens_predicted"];
+              std::string to_send = result.result_json["content"];
+              nitro_utils::ltrim(to_send);
+              respData = create_full_return_json(
+                  nitro_utils::generate_random_string(20), "_", to_send, "_",
+                  prompt_tokens, predicted_tokens);
+            } else {
+              respData["message"] = "Internal error during inference";
+              LOG_ERROR_REQUEST(request_id) << "Error during inference";
+            }
+            auto resp = nitro_utils::nitroHttpJsonResponse(respData);
+            cb(resp);
+            LOG_INFO_REQUEST(request_id) << "Inference completed";
+          }
+        });
   }
 }
 
@@ -468,14 +471,14 @@ void llamaCPP::Embedding(
     // Model is loaded
     const auto& jsonBody = req->getJsonObject();
     // Run embedding
-    EmbeddingImpl(jsonBody, callback);
+    EmbeddingImpl(jsonBody, std::move(callback));
     return;
   }
 }
 
 void llamaCPP::EmbeddingImpl(
     std::shared_ptr<Json::Value> jsonBody,
-    std::function<void(const HttpResponsePtr&)>& callback) {
+    std::function<void(const HttpResponsePtr&)>&& callback) {
   int request_id = ++no_of_requests;
   LOG_INFO_REQUEST(request_id) << "Generating reponse for embedding request";
   // Queue embedding task
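
Judging from the diff, the race came from the non-streaming branch calling llama.request_completion and llama.next_result directly on the HTTP handler thread while streaming requests ran on the controller's task queue; the commit pushes the non-streaming work onto the same queue and moves the callback and request data into the queued lambda, which is also why the callback parameters change to rvalue references. The standalone C++ sketch below illustrates only that pattern; SerialTaskQueue, StreamInference, NonStreamInference and the main() driver are hypothetical stand-ins for drogon's task queue and the real controller, not the project's code.

// Minimal sketch, assuming a single worker is enough to serialize inference:
// both entry points enqueue their work and move the response callback into the
// queued lambda, so stream and non-stream requests can never run concurrently.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>

// One worker thread that executes queued tasks strictly in submission order.
class SerialTaskQueue {
 public:
  SerialTaskQueue() : worker_([this] { Run(); }) {}
  ~SerialTaskQueue() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    cv_.notify_one();
    worker_.join();
  }
  void RunTaskInQueue(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  void Run() {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return done_ || !tasks_.empty(); });
        if (tasks_.empty()) return;  // drained and shutting down
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task();  // runs outside the lock, one task at a time
    }
  }
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool done_ = false;
  std::thread worker_;  // declared last so it starts after the other members
};

using Callback = std::function<void(const std::string&)>;

// The callback is taken by && and moved into the lambda (cb = std::move(callback)),
// mirroring the change from std::function<...>& to std::function<...>&& parameters.
void StreamInference(SerialTaskQueue& queue, Callback&& callback) {
  queue.RunTaskInQueue([cb = std::move(callback)] {
    cb("streamed chunk");  // real code: start the chunked streaming response here
  });
}

void NonStreamInference(SerialTaskQueue& queue, Callback&& callback) {
  queue.RunTaskInQueue([cb = std::move(callback)] {
    cb("full completion");  // real code: wait for the full result, then respond
  });
}

int main() {
  SerialTaskQueue queue;
  StreamInference(queue, [](const std::string& r) { std::cout << r << '\n'; });
  NonStreamInference(queue, [](const std::string& r) { std::cout << r << '\n'; });
  return 0;  // the queue destructor drains remaining tasks and joins the worker
}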

controllers/llamaCPP.h

Lines changed: 3 additions & 3 deletions
@@ -97,10 +97,10 @@ class llamaCPP : public drogon::HttpController<llamaCPP>,
 
   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
   void InferenceImpl(inferences::ChatCompletionRequest&& completion,
-                     std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
   void EmbeddingImpl(std::shared_ptr<Json::Value> jsonBody,
-                     std::function<void(const HttpResponsePtr&)>& callback);
-  bool CheckModelLoaded(std::function<void(const HttpResponsePtr&)>& callback);
+                     std::function<void(const HttpResponsePtr&)>&& callback);
+  bool CheckModelLoaded(const std::function<void(const HttpResponsePtr&)>& callback);
   void WarmupModel();
   void BackgroundTask();
   void StopBackgroundTask();
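
The header change encodes the ownership convention behind the fix: callbacks that the implementation consumes (by moving them into a queued lambda) are now taken by rvalue reference, while CheckModelLoaded, which only invokes the callback to report a load error, takes it by const reference. A minimal sketch of that convention, using hypothetical names (Handler, CheckReady, RunDeferred) rather than the project's API:

#include <functional>
#include <iostream>
#include <string>
#include <utility>

using Handler = std::function<void(const std::string&)>;

// Observes the callback without taking ownership, so const& is sufficient
// (this mirrors CheckModelLoaded's new signature).
bool CheckReady(bool ready, const Handler& on_error) {
  if (!ready) on_error("model not loaded");
  return ready;
}

// Consumes the callback: it is moved into deferred work, so it is taken by &&
// (this mirrors the new InferenceImpl/EmbeddingImpl signatures).
void RunDeferred(Handler&& on_done) {
  auto task = [cb = std::move(on_done)] { cb("done"); };
  task();  // in the real controller this lambda is handed to the task queue
}

int main() {
  Handler print = [](const std::string& m) { std::cout << m << '\n'; };
  if (CheckReady(true, print)) RunDeferred(std::move(print));
  return 0;
}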
