
Commit c893b4b

Feat/stream request python engine (#1829)
* chore: add document
* feat: update engine interface
* chore: add document
* feat: update engine interface
* Feat: init python engine
* Fix: conflict
* feat: add python engine implementation
* Fix: CI build windows
* Fix: CI build windows
* feat: support download python model from cortexso
* feat: add inference interface
* feat: integrate to cortex cpp
* fix: remove python engine load engine option
* Feat: init environment interface
* feat: move virtual environment inside model
* Update CMakeLists.txt
* Update CMakeLists.txt
* fix: CI build
* fix: move log of python to cortex logs folder
* fix: unit test for remote engine because of changed location of template renderer
* fix: CI build windows
* fix: CI build windows
* feat: add depends model.yml for python engine
* fix: CI build
* stream response
* update set permission api
* Fix: comment
* Feat: stream response
* fix: run concurrent request with stream mode
* Fix: remove unnecessary interface
* Fix comment
* Fix: comment review
* fix comment
* fix comment

---------

Co-authored-by: James <namnh0122@gmail.com>
1 parent 22ff0a1 commit c893b4b

File tree

3 files changed, +127 −28 lines changed

engine/controllers/server.cc

Lines changed: 43 additions & 14 deletions
@@ -129,6 +129,9 @@ void server::FineTuning(
 
 void server::Inference(const HttpRequestPtr& req,
                        std::function<void(const HttpResponsePtr&)>&& callback) {
+
+  auto json_body = req->getJsonObject();
+
   LOG_TRACE << "Start inference";
   auto q = std::make_shared<SyncQueue>();
   auto ir = inference_svc_->HandleInference(q, req->getJsonObject());
@@ -141,20 +144,34 @@ void server::Inference(const HttpRequestPtr& req,
     callback(resp);
     return;
   }
+
+  bool is_stream =
+      (*json_body).get("stream", false).asBool() ||
+      (*json_body).get("body", Json::Value()).get("stream", false).asBool();
+
   LOG_TRACE << "Wait to inference";
-  auto [status, res] = q->wait_and_pop();
-  LOG_DEBUG << "response: " << res.toStyledString();
-  auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
-  resp->setStatusCode(
-      static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
-  callback(resp);
-  LOG_TRACE << "Done inference";
+  if (is_stream) {
+    auto model_id = (*json_body).get("model", "invalid_model").asString();
+    auto engine_type = [this, &json_body]() -> std::string {
+      if (!inference_svc_->HasFieldInReq(json_body, "engine")) {
+        return kLlamaRepo;
+      } else {
+        return (*(json_body)).get("engine", kLlamaRepo).asString();
+      }
+    }();
+    ProcessStreamRes(callback, q, engine_type, model_id);
+  } else {
+    ProcessNonStreamRes(callback, *q);
+    LOG_TRACE << "Done inference";
+  }
 }
 
 void server::RouteRequest(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) {
 
+  auto json_body = req->getJsonObject();
+
   LOG_TRACE << "Start route request";
   auto q = std::make_shared<SyncQueue>();
   auto ir = inference_svc_->HandleRouteRequest(q, req->getJsonObject());
@@ -167,14 +184,26 @@ void server::RouteRequest(
     callback(resp);
     return;
   }
+  auto is_stream =
+      (*json_body).get("stream", false).asBool() ||
+      (*json_body).get("body", Json::Value()).get("stream", false).asBool();
   LOG_TRACE << "Wait to route request";
-  auto [status, res] = q->wait_and_pop();
-  LOG_DEBUG << "response: " << res.toStyledString();
-  auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
-  resp->setStatusCode(
-      static_cast<drogon::HttpStatusCode>(status["status_code"].asInt()));
-  callback(resp);
-  LOG_TRACE << "Done route request";
+  if (is_stream) {
+
+    auto model_id = (*json_body).get("model", "invalid_model").asString();
+    auto engine_type = [this, &json_body]() -> std::string {
+      if (!inference_svc_->HasFieldInReq(json_body, "engine")) {
+        return kLlamaRepo;
+      } else {
+        return (*(json_body)).get("engine", kLlamaRepo).asString();
+      }
+    }();
+    ProcessStreamRes(callback, q, engine_type, model_id);
+  } else {
+    ProcessNonStreamRes(callback, *q);
+    LOG_TRACE << "Done route request";
+  }
+
 }
 
 void server::LoadModel(const HttpRequestPtr& req,
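
Both handlers now branch on the same streaming check: a request is treated as a stream when "stream" is true either at the top level of the JSON body or inside a nested "body" object. The standalone sketch below is not part of the commit; the request payloads and the helper name IsStreamRequest are invented for illustration, but the expression mirrors the one added above.

#include <json/json.h>

#include <iostream>

// Mirrors the is_stream expression from server.cc: true when "stream" is set
// at the top level or inside a nested "body" object.
static bool IsStreamRequest(const Json::Value& json_body) {
  return json_body.get("stream", false).asBool() ||
         json_body.get("body", Json::Value()).get("stream", false).asBool();
}

int main() {
  Json::Value top_level;  // e.g. {"model": "my-model", "stream": true}
  top_level["model"] = "my-model";
  top_level["stream"] = true;

  Json::Value nested;  // e.g. {"model": "my-model", "body": {"stream": true}}
  nested["model"] = "my-model";
  nested["body"]["stream"] = true;

  std::cout << IsStreamRequest(top_level) << " "  // prints 1
            << IsStreamRequest(nested) << "\n";   // prints 1
  return 0;
}

ProcessStreamRes then drains the shared SyncQueue chunk by chunk while the client connection stays open, and ProcessNonStreamRes presumably keeps the single-response behaviour that the removed wait_and_pop block used to implement.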

engine/extensions/python-engine/python_engine.cc

Lines changed: 75 additions & 6 deletions
@@ -16,7 +16,8 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb,
   return size * nmemb;
 }
 
-PythonEngine::PythonEngine() {}
+PythonEngine::PythonEngine() : q_(4 /*n_parallel*/, "python_engine") {}
+
 
 PythonEngine::~PythonEngine() {
   curl_global_cleanup();
@@ -169,7 +170,7 @@ bool PythonEngine::TerminateModelProcess(const std::string& model) {
 }
 CurlResponse PythonEngine::MakeGetRequest(const std::string& model,
                                           const std::string& path) {
-  auto config = models_[model];
+  auto const& config = models_[model];
   std::string full_url = "http://localhost:" + config.port + path;
   CurlResponse response;
 
@@ -184,7 +185,7 @@ CurlResponse PythonEngine::MakeGetRequest(const std::string& model,
 }
 CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model,
                                              const std::string& path) {
-  auto config = models_[model];
+  auto const& config = models_[model];
   std::string full_url = "http://localhost:" + config.port + path;
   CurlResponse response;
 
@@ -203,7 +204,7 @@ CurlResponse PythonEngine::MakeDeleteRequest(const std::string& model,
 CurlResponse PythonEngine::MakePostRequest(const std::string& model,
                                            const std::string& path,
                                            const std::string& body) {
-  auto config = models_[model];
+  auto const& config = models_[model];
  std::string full_url = "http://localhost:" + config.port + path;
 
  CurlResponse response;
@@ -450,6 +451,63 @@ void PythonEngine::HandleChatCompletion(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {}
 
+CurlResponse PythonEngine::MakeStreamPostRequest(
+    const std::string& model, const std::string& path, const std::string& body,
+    const std::function<void(Json::Value&&, Json::Value&&)>& callback) {
+  auto const& config = models_[model];
+  CURL* curl = curl_easy_init();
+  CurlResponse response;
+
+  if (!curl) {
+    response.error = true;
+    response.error_message = "Failed to initialize CURL";
+    return response;
+  }
+
+  std::string full_url = "http://localhost:" + config.port + path;
+
+  struct curl_slist* headers = nullptr;
+  headers = curl_slist_append(headers, "Content-Type: application/json");
+  headers = curl_slist_append(headers, "Accept: text/event-stream");
+  headers = curl_slist_append(headers, "Cache-Control: no-cache");
+  headers = curl_slist_append(headers, "Connection: keep-alive");
+
+  StreamContext context{
+      std::make_shared<std::function<void(Json::Value&&, Json::Value&&)>>(
+          callback),
+      ""};
+
+  curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
+  curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
+  curl_easy_setopt(curl, CURLOPT_POST, 1L);
+  curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
+  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback);
+  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context);
+  curl_easy_setopt(curl, CURLOPT_TRANSFER_ENCODING, 1L);
+
+  CURLcode res = curl_easy_perform(curl);
+
+  if (res != CURLE_OK) {
+    response.error = true;
+    response.error_message = curl_easy_strerror(res);
+
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = true;
+    status["is_stream"] = true;
+    status["status_code"] = 500;
+
+    Json::Value error;
+    error["error"] = response.error_message;
+    callback(std::move(status), std::move(error));
+  }
+
+  curl_slist_free_all(headers);
+  curl_easy_cleanup(curl);
+  return response;
+}
+
+
 void PythonEngine::HandleInference(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
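
MakeStreamPostRequest relies on libcurl's write callback to deliver the response incrementally: CURLOPT_WRITEFUNCTION fires repeatedly while curl_easy_perform is still blocking, which is what lets StreamWriteCallback forward SSE chunks to the caller as they arrive. Below is a minimal, self-contained sketch of that generic libcurl pattern, not the engine's code; the URL and the OnChunk helper are placeholders.

#include <curl/curl.h>

#include <iostream>
#include <string>

// libcurl calls this once per received chunk, while curl_easy_perform()
// is still running; returning a short count aborts the transfer.
static size_t OnChunk(char* ptr, size_t size, size_t nmemb, void* userdata) {
  auto* total = static_cast<std::string*>(userdata);
  total->append(ptr, size * nmemb);
  std::cout << "received chunk of " << size * nmemb << " bytes\n";
  return size * nmemb;
}

int main() {
  curl_global_init(CURL_GLOBAL_DEFAULT);
  CURL* curl = curl_easy_init();
  if (!curl)
    return 1;

  std::string body;
  // Placeholder endpoint for illustration only.
  curl_easy_setopt(curl, CURLOPT_URL, "http://localhost:8000/stream");
  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, OnChunk);
  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &body);

  CURLcode res = curl_easy_perform(curl);  // blocks until the stream closes
  if (res != CURLE_OK)
    std::cerr << curl_easy_strerror(res) << "\n";

  curl_easy_cleanup(curl);
  curl_global_cleanup();
  return 0;
}

Because curl_easy_perform blocks for the lifetime of the SSE connection, the engine runs it on a worker queue rather than on the request thread, as the HandleInference hunk below shows.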
@@ -485,7 +543,8 @@ void PythonEngine::HandleInference(
 
   // Render with error handling
   try {
-    transformed_request = renderer_.Render(transform_request, *json_body);
+    transformed_request = renderer_.Render(transform_request, body);
+
   } catch (const std::exception& e) {
     throw std::runtime_error("Template rendering error: " +
                              std::string(e.what()));
@@ -504,7 +563,17 @@ void PythonEngine::HandleInference(
 
   CurlResponse response;
   if (method == "post") {
-    response = MakePostRequest(model, path, transformed_request);
+    if (body.isMember("stream") && body["stream"].asBool()) {
+      q_.runTaskInQueue(
+          [this, model, path, transformed_request, cb = std::move(callback)] {
+            MakeStreamPostRequest(model, path, transformed_request, cb);
+          });
+
+      return;
+    } else {
+      response = MakePostRequest(model, path, transformed_request);
+    }
+
   } else if (method == "get") {
     response = MakeGetRequest(model, path);
   } else if (method == "delete") {
engine/extensions/python-engine/python_engine.h

Lines changed: 9 additions & 8 deletions
@@ -8,6 +8,8 @@
 #include <string>
 #include <unordered_map>
 #include "config/model_config.h"
+#include "trantor/utils/ConcurrentTaskQueue.h"
+
 #include "cortex-common/EngineI.h"
 #include "extensions/template_renderer.h"
 #include "utils/file_logger.h"
@@ -44,19 +46,12 @@ static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
   while ((pos = context->buffer.find('\n')) != std::string::npos) {
     std::string line = context->buffer.substr(0, pos);
     context->buffer = context->buffer.substr(pos + 1);
+    LOG_DEBUG << "line: "<<line;
 
     // Skip empty lines
     if (line.empty() || line == "\r")
       continue;
 
-    // Remove "data: " prefix if present
-    // if (line.substr(0, 6) == "data: ")
-    // {
-    //   line = line.substr(6);
-    // }
-
-    // Skip [DONE] message
-    std::cout << line << std::endl;
     if (line == "data: [DONE]") {
       Json::Value status;
       status["is_done"] = true;
@@ -99,6 +94,8 @@ class PythonEngine : public EngineI {
   extensions::TemplateRenderer renderer_;
   std::unique_ptr<trantor::FileLogger> async_file_logger_;
   std::unordered_map<std::string, pid_t> processMap;
+  trantor::ConcurrentTaskQueue q_;
+
 
   // Helper functions
   CurlResponse MakePostRequest(const std::string& model,
@@ -108,6 +105,10 @@ class PythonEngine : public EngineI {
                                const std::string& path);
   CurlResponse MakeDeleteRequest(const std::string& model,
                                  const std::string& path);
+  CurlResponse MakeStreamPostRequest(
+      const std::string& model, const std::string& path,
+      const std::string& body,
+      const std::function<void(Json::Value&&, Json::Value&&)>& callback);
 
   // Process manager functions
   pid_t SpawnProcess(const std::string& model,
