This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 71e0dd2

Merge pull request #425 from janhq/423-feat-refactor-codebase-and-add-inferenceprovider-so-that-new-inference-engine-can-be-easily-added

423 feat: refactor codebase and add InferenceProvider so that new inference engines can be easily added

2 parents 4716d37 + 571f1ca · commit 71e0dd2

File tree

5 files changed: +80 −42 lines changed

CMakeLists.txt
common/base.cc
common/base.h
controllers/llamaCPP.cc
controllers/llamaCPP.h

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -76,6 +76,7 @@ else()
 endif()

 aux_source_directory(controllers CTL_SRC)
+aux_source_directory(common COMMON_SRC)
 # aux_source_directory(filters FILTER_SRC) aux_source_directory(plugins
 # PLUGIN_SRC) aux_source_directory(models MODEL_SRC)

@@ -86,7 +87,7 @@ aux_source_directory(controllers CTL_SRC)

 target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
 # ${CMAKE_CURRENT_SOURCE_DIR}/models)
-target_sources(${PROJECT_NAME} PRIVATE ${CTL_SRC})
+target_sources(${PROJECT_NAME} PRIVATE ${CTL_SRC} ${COMMON_SRC})
 # ${FILTER_SRC} ${PLUGIN_SRC} ${MODEL_SRC})
 # ##############################################################################
 # uncomment the following line for dynamically loading views set_property(TARGET

common/base.cc

Whitespace-only changes.

common/base.h

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#pragma once
+#include <drogon/HttpController.h>
+
+using namespace drogon;
+
+#pragma once
+#include <drogon/HttpController.h>
+
+using namespace drogon;
+
+class BaseProvider {
+public:
+  virtual ~BaseProvider() {}
+
+  // General inference method
+  virtual void
+  inference(const HttpRequestPtr &req,
+            std::function<void(const HttpResponsePtr &)> &&callback) = 0;
+
+  // Model management
+  virtual void
+  loadModel(const HttpRequestPtr &req,
+            std::function<void(const HttpResponsePtr &)> &&callback) = 0;
+  virtual void
+  unloadModel(const HttpRequestPtr &req,
+              std::function<void(const HttpResponsePtr &)> &&callback) = 0;
+  virtual void
+  modelStatus(const HttpRequestPtr &req,
+              std::function<void(const HttpResponsePtr &)> &&callback) = 0;
+};
+
+class ChatProvider : public BaseProvider {
+public:
+  virtual ~ChatProvider() {}
+
+  // Implement embedding functionality specific to chat
+  virtual void
+  embedding(const HttpRequestPtr &req,
+            std::function<void(const HttpResponsePtr &)> &&callback) = 0;
+
+  // The derived class can also override other methods if needed
+};
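With this interface in place, a new inference engine only needs to subclass ChatProvider (or BaseProvider for engines without embeddings) and implement the pure virtual methods. A minimal sketch of what that could look like — EchoProvider is a hypothetical name for illustration, not part of this commit:

#pragma once
#include <drogon/HttpController.h>
#include "common/base.h"

using namespace drogon;

// Hypothetical provider (illustration only): echoes the request body back,
// showing exactly which handlers a new engine must implement.
class EchoProvider : public ChatProvider {
public:
  void inference(const HttpRequestPtr &req,
                 std::function<void(const HttpResponsePtr &)> &&callback) override {
    auto resp = HttpResponse::newHttpResponse();
    resp->setBody(std::string(req->getBody())); // echo the prompt back
    callback(resp);
  }
  void embedding(const HttpRequestPtr &req,
                 std::function<void(const HttpResponsePtr &)> &&callback) override {
    callback(HttpResponse::newHttpResponse()); // no-op embedding
  }
  void loadModel(const HttpRequestPtr &req,
                 std::function<void(const HttpResponsePtr &)> &&callback) override {
    callback(HttpResponse::newHttpResponse()); // nothing to load
  }
  void unloadModel(const HttpRequestPtr &req,
                   std::function<void(const HttpResponsePtr &)> &&callback) override {
    callback(HttpResponse::newHttpResponse()); // nothing to unload
  }
  void modelStatus(const HttpRequestPtr &req,
                   std::function<void(const HttpResponsePtr &)> &&callback) override {
    auto resp = HttpResponse::newHttpResponse();
    resp->setBody("{\"model_loaded\": true}"); // static status report
    callback(resp);
  }
};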

controllers/llamaCPP.cc

Lines changed: 12 additions & 14 deletions
@@ -132,6 +132,15 @@ std::string create_return_json(const std::string &id, const std::string &model,
   return Json::writeString(writer, root);
 }

+llamaCPP::llamaCPP() {
+  // Some default values for now below
+  log_disable(); // Disable the log-to-file feature to
+                 // reduce bloat on the
+                 // target system
+};
+
+llamaCPP::~llamaCPP() { stopBackgroundTask(); }
+
 void llamaCPP::warmupModel() {
   json pseudo;

@@ -148,29 +157,18 @@ void llamaCPP::warmupModel() {
   return;
 }

-void llamaCPP::handlePrelight(
-    const HttpRequestPtr &req,
-    std::function<void(const HttpResponsePtr &)> &&callback) {
-  auto resp = drogon::HttpResponse::newHttpResponse();
-  resp->setStatusCode(drogon::HttpStatusCode::k200OK);
-  resp->addHeader("Access-Control-Allow-Origin", "*");
-  resp->addHeader("Access-Control-Allow-Methods", "POST, OPTIONS");
-  resp->addHeader("Access-Control-Allow-Headers", "*");
-  callback(resp);
-}
-
-void llamaCPP::chatCompletion(
+void llamaCPP::inference(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {

   const auto &jsonBody = req->getJsonObject();
   // Check if model is loaded
   checkModelLoaded(callback);

-  chatCompletionImpl(jsonBody, callback);
+  inferenceImpl(jsonBody, callback);
 }

-void llamaCPP::chatCompletionImpl(
+void llamaCPP::inferenceImpl(
     std::shared_ptr<Json::Value> jsonBody,
     std::function<void(const HttpResponsePtr &)> &callback) {
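Only the handler names change here; the public HTTP routes stay the same, so existing clients are unaffected. For illustration, a Drogon client call against the OpenAI-compatible route might look like the sketch below — the host and port 3928 are assumptions, not taken from this diff:

#include <drogon/drogon.h>
#include <iostream>
#include <json/json.h>

using namespace drogon;

int main() {
  // Build a chat-completion request body; the field names follow the
  // OpenAI-style schema the route advertises.
  Json::Value body;
  body["messages"][0]["role"] = "user";
  body["messages"][0]["content"] = "Hello!";

  auto client = HttpClient::newHttpClient("http://127.0.0.1:3928");
  auto req = HttpRequest::newHttpJsonRequest(body);
  req->setMethod(Post);
  req->setPath("/v1/chat/completions");

  client->sendRequest(req, [](ReqResult result, const HttpResponsePtr &resp) {
    if (result == ReqResult::Ok)
      std::cout << resp->getBody() << std::endl; // print the completion JSON
  });
  app().run(); // keep the event loop alive for the async callback
  return 0;
}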

controllers/llamaCPP.h

Lines changed: 24 additions & 27 deletions
@@ -24,6 +24,7 @@
 #define CPPHTTPLIB_NO_EXCEPTIONS 1
 #endif

+#include "common/base.h"
 #include "utils/json.hpp"

 // auto generated files (update with ./deps.sh)

@@ -2510,45 +2511,42 @@ append_to_generated_text_from_generated_token_probs(llama_server_context &llama,
 using namespace drogon;

 namespace inferences {
-class llamaCPP : public drogon::HttpController<llamaCPP> {
+class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {
 public:
-  llamaCPP() {
-    // Some default values for now below
-    log_disable(); // Disable the log-to-file feature to
-                   // reduce bloat on the
-                   // target system
-  }
-
-  ~llamaCPP() { stopBackgroundTask(); }
+  llamaCPP();
+  ~llamaCPP();
   METHOD_LIST_BEGIN
   // list path definitions here;
-  METHOD_ADD(llamaCPP::chatCompletion, "chat_completion", Post);
+  METHOD_ADD(llamaCPP::inference, "chat_completion", Post);
   METHOD_ADD(llamaCPP::embedding, "embedding", Post);
   METHOD_ADD(llamaCPP::loadModel, "loadmodel", Post);
   METHOD_ADD(llamaCPP::unloadModel, "unloadmodel", Get);
   METHOD_ADD(llamaCPP::modelStatus, "modelstatus", Get);

   // OpenAI-compatible paths
-  ADD_METHOD_TO(llamaCPP::chatCompletion, "/v1/chat/completions", Post);
-  ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options);
+  ADD_METHOD_TO(llamaCPP::inference, "/v1/chat/completions", Post);
+  // ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options);
+  // NOTE: preflight will be added back when browser support is properly planned

   ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post);
-  ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);
+  // ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);

   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
-  void chatCompletion(const HttpRequestPtr &req,
-                      std::function<void(const HttpResponsePtr &)> &&callback);
-  void handlePrelight(const HttpRequestPtr &req,
-                      std::function<void(const HttpResponsePtr &)> &&callback);
-  void embedding(const HttpRequestPtr &req,
-                 std::function<void(const HttpResponsePtr &)> &&callback);
-  void loadModel(const HttpRequestPtr &req,
-                 std::function<void(const HttpResponsePtr &)> &&callback);
-  void unloadModel(const HttpRequestPtr &req,
-                   std::function<void(const HttpResponsePtr &)> &&callback);
-  void modelStatus(const HttpRequestPtr &req,
-                   std::function<void(const HttpResponsePtr &)> &&callback);
+  void
+  inference(const HttpRequestPtr &req,
+            std::function<void(const HttpResponsePtr &)> &&callback) override;
+  void
+  embedding(const HttpRequestPtr &req,
+            std::function<void(const HttpResponsePtr &)> &&callback) override;
+  void
+  loadModel(const HttpRequestPtr &req,
+            std::function<void(const HttpResponsePtr &)> &&callback) override;
+  void
+  unloadModel(const HttpRequestPtr &req,
+              std::function<void(const HttpResponsePtr &)> &&callback) override;
+  void
+  modelStatus(const HttpRequestPtr &req,
+              std::function<void(const HttpResponsePtr &)> &&callback) override;

 private:
   llama_server_context llama;

@@ -2569,8 +2567,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
   std::string grammar_file_content;

   bool loadModelImpl(std::shared_ptr<Json::Value> jsonBody);
-  void
-  chatCompletionImpl(std::shared_ptr<Json::Value> jsonBody,
+  void inferenceImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr &)> &callback);
   void embeddingImpl(std::shared_ptr<Json::Value> jsonBody,
                      std::function<void(const HttpResponsePtr &)> &callback);
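Because llamaCPP still derives from drogon::HttpController<llamaCPP>, Drogon registers all the routes above automatically; inheriting from ChatProvider only adds the compile-time guarantee that the handler set is complete. A minimal entry point that would serve these routes — the listener address and port are assumptions, not shown in this diff:

#include <drogon/drogon.h>

int main() {
  // The llamaCPP controller self-registers via HttpController<>, so no
  // explicit route wiring is needed here; starting the app exposes the
  // routes declared between METHOD_LIST_BEGIN and METHOD_LIST_END.
  drogon::app().addListener("0.0.0.0", 3928).run();
  return 0;
}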
