Skip to content

Commit b78ed8a

Browse files
committed
Adapt to the new ThreadingContext API of gemma.cpp
1 parent c7f211e commit b78ed8a

12 files changed

Lines changed: 138 additions & 86 deletions

File tree

README.md

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ end
6868

6969
Show information of cgemma module.
7070

71-
### cgemma.scheduler.config
71+
### cgemma.scheduler
7272

73-
**syntax:** `<boolean>ok, <string>err = cgemma.scheduler.config(<table>options)`
73+
**syntax:** `<cgemma.scheduler>sched, <string>err = cgemma.scheduler([<table>options])`
7474

75-
Configure the backend scheduler.
75+
Create a scheduler instance.
7676

77-
A successful call returns `true`. Otherwise, it returns `false` and a string describing the error.
77+
A successful call returns a scheduler instance. Otherwise, it returns `nil` and a string describing the error.
7878

7979
Available options and default values:
8080

@@ -92,20 +92,12 @@ Available options and default values:
9292
}
9393
```
9494

95-
> [!NOTE]
96-
> This method can only be called for configuration before the backend scheduler initialization is triggered. If the backend scheduler is triggered without being configured, it will be initialized with default options.
97-
9895
### cgemma.scheduler.cpu\_topology
9996

100-
**syntax:** `<string>desc, <string>err = cgemma.scheduler.cpu_topology()`
97+
**syntax:** `<string>desc = sched:cpu_topology()`
10198

10299
Query CPU topology.
103100

104-
A successful call returns the CPU topology information. Otherwise, it returns `nil` and a string describing the error.
105-
106-
> [!NOTE]
107-
> Calling this method will trigger the backend scheduler initialization.
108-
109101
### cgemma.new
110102

111103
**syntax:** `<cgemma.instance>inst, <string>err = cgemma.new(<table>options)`
@@ -123,13 +115,14 @@ Available options:
123115
map = -1, -- Enable memory-mapping? (-1 means auto, 0 means no, 1 means yes)
124116
to_bf16 = -1, -- Convert weights to bf16? (-1 means auto, 0 means no, 1 means yes)
125117
seed = 42, -- Random seed. (If omitted, a random seed is generated.)
118+
scheduler = sched_inst, -- Scheduler instance; if not provided, a default
119+
-- scheduler will be attached.
126120
disabled_words = {...}, -- Words you don't want to generate.
127121
}
128122
```
129123

130124
> [!NOTE]
131-
> 1. If the weights file is not in the new single-file format, then `tokenizer` and `model` options are required;
132-
> 2. Calling this method will trigger the backend scheduler initialization.
125+
> If the weights file is not in the new single-file format, then the `tokenizer` option is required.
133126
134127
### cgemma.instance.disabled\_tokens
135128

demo/src/app.lua

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
local ok, err = require("cgemma").scheduler.config(config().scheduler)
2-
if not ok then
1+
local sched, err = require("cgemma").scheduler(config().scheduler)
2+
if not sched then
33
ngx.log(ngx.ERR, "cgemma error: ", err)
44
end
55

66
function gemma_inst()
77
if not worker_gemma_inst then
8-
local gemma, err = require("cgemma").new(config().gemma)
8+
local gemma_cfg = config().gemma
9+
gemma_cfg.scheduler = sched
10+
local gemma, err = require("cgemma").new(gemma_cfg)
911
if not gemma then
1012
ngx.log(ngx.ERR, "cgemma error: ", err)
1113
ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)

demo/src/init_kaggle.lua

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
function config()
22
return {
3-
scheduler = {},
43
gemma = {
54
tokenizer = "tokenizer.spm",
65
weights = "4b-it-sfp.sbs"

src/batch.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ gcpp::TimingInfo generate(cgemma::instance* inst, const std::vector<cgemma::sess
106106
.kv_cache = ctx.sess->kv_cache()
107107
});
108108
}
109-
inst->model().GenerateBatch(cfg, queries, inst->env(), timing);
109+
inst->model().GenerateBatch(cfg, queries, inst->matmul_env(), timing);
110110
return timing;
111111
}
112112

src/cgemma.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "session.hpp"
44
#include "image_tokens.hpp"
55
#include "batch.hpp"
6-
#include "scheduler.hpp"
76
#include <hwy/timer.h>
87
#include <hwy/per_target.h>
98
#include <hwy/targets.h>
@@ -41,10 +40,12 @@ int info(lua_State* L) {
4140
int luaopen_cgemma(lua_State* L) {
4241
constexpr const luaL_Reg entries[] = {
4342
{"info", info},
43+
{"scheduler", cgemma::scheduler::create},
4444
{"new", cgemma::instance::create},
4545
{"batch", cgemma::batch},
4646
{nullptr, nullptr}
4747
};
48+
cgemma::scheduler::declare(L);
4849
cgemma::instance::declare(L);
4950
cgemma::session::declare(L);
5051
cgemma::image_tokens::declare(L);
@@ -55,7 +56,5 @@ int luaopen_cgemma(lua_State* L) {
5556
lua_setfield(L, -2, "_NAME");
5657
lua_pushliteral(L, "1.0");
5758
lua_setfield(L, -2, "_VERSION");
58-
cgemma::scheduler::declare(L);
59-
lua_setfield(L, -2, "scheduler");
6059
return 1;
6160
}

src/image_tokens.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,14 @@ int create(lua_State* L) {
7474
gcpp::ImageTokens tks(
7575
"image_tokens",
7676
gcpp::Extents2D(model_cfg.vit_config.seq_len / (model_cfg.vit_config.pool_dim * model_cfg.vit_config.pool_dim), model_cfg.model_dim),
77+
inst->threading_ctx().allocator,
7778
gcpp::MatPadding::kOdd
7879
);
80+
tks.AllocateAndAttachRowPtrs(inst->matmul_env().row_ptrs);
7981
gcpp::RuntimeConfig cfg;
8082
cfg.gen = &inst->rnd();
8183
cfg.verbosity = 0;
82-
inst->model().GenerateImageTokens(cfg, tks.Rows(), img, tks, inst->env());
84+
inst->model().GenerateImageTokens(cfg, tks.Rows(), img, tks, inst->matmul_env());
8385
auto ud = lua_newuserdata(L, sizeof(gcpp::ImageTokens));
8486
new(ud) gcpp::ImageTokens(std::move(tks));
8587
luaL_getmetatable(L, name);

src/instance.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,19 @@ int disabled_tokens(lua_State* L) {
3232

3333
namespace cgemma {
3434

35-
instance::instance(int argc, char* argv[], unsigned int seed)
35+
instance::instance(int argc, char* argv[], unsigned int seed, scheduler* sched)
3636
: args_(argc, argv)
37-
, rnd_(seed) {
38-
env_ = std::make_unique<gcpp::MatMulEnv>(gcpp::ThreadingContext::Get());
37+
, rnd_(seed)
38+
, sched_(sched) {
39+
if (!sched_) {
40+
default_sched_ = std::make_unique<scheduler>();
41+
sched_ = default_sched_.get();
42+
}
3943
// Disable heuristics loading weights into BF16
4044
gcpp::InferenceArgs infa;
4145
infa.prefill_tbatch_size = 0;
4246
infa.decode_qbatch_size = 0;
43-
model_ = std::make_unique<gcpp::Gemma>(args_, infa, env_->ctx.pools);
47+
model_ = std::make_unique<gcpp::Gemma>(args_, infa, threading_ctx());
4448
}
4549

4650
bool instance::instruction_tuned() const {
@@ -120,8 +124,11 @@ int instance::create(lua_State* L) {
120124
seed = rd();
121125
}
122126
lua_pop(L, 1);
127+
lua_getfield(L, 1, "scheduler");
128+
auto sched = scheduler::to(L, -1);
129+
lua_pop(L, 1);
123130
auto ud = lua_newuserdata(L, sizeof(instance));
124-
auto inst = new(ud) instance(argc, argv, seed);
131+
auto inst = new(ud) instance(argc, argv, seed, sched);
125132
luaL_getmetatable(L, name);
126133
lua_setmetatable(L, -2);
127134
lua_getfield(L, 1, "disabled_words");

src/instance.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef CGEMMA_INSTANCE_HPP
22
#define CGEMMA_INSTANCE_HPP
33

4-
#include <lua.hpp>
4+
#include "scheduler.hpp"
55
#include <gemma/gemma.h>
66
#include <gemma/gemma_args.h>
77
#include <unordered_set>
@@ -13,15 +13,14 @@ namespace cgemma {
1313
constexpr const int PAD_ID = 0;
1414
constexpr const int UNK_ID = 3;
1515

16-
class session;
17-
1816
class instance {
1917
public:
20-
instance(int argc, char* argv[], unsigned int seed);
18+
instance(int argc, char* argv[], unsigned int seed, scheduler* sched);
2119

2220
const gcpp::LoaderArgs& args() const { return args_; }
2321
std::mt19937& rnd() { return rnd_; }
24-
gcpp::MatMulEnv& env() const { return *env_; }
22+
gcpp::ThreadingContext& threading_ctx() const { return sched_->threading_ctx(); }
23+
gcpp::MatMulEnv& matmul_env() const { return sched_->matmul_env(); }
2524
gcpp::Gemma& model() const { return *model_; }
2625
const std::unordered_set<int>& disabled_tokens() const { return disabled_tokens_; }
2726
size_t max_tokens() const { return model_->GetModelConfig().max_seq_len; }
@@ -35,7 +34,8 @@ class instance {
3534
private:
3635
gcpp::LoaderArgs args_;
3736
std::mt19937 rnd_;
38-
std::unique_ptr<gcpp::MatMulEnv> env_;
37+
scheduler* sched_;
38+
std::unique_ptr<scheduler> default_sched_;
3939
std::unique_ptr<gcpp::Gemma> model_;
4040
std::unordered_set<int> disabled_tokens_;
4141
};

src/scheduler.cpp

Lines changed: 72 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,61 @@ namespace {
66

77
constexpr const char name[] = "cgemma.scheduler";
88

9-
int config(lua_State* L) {
10-
if (gcpp::ThreadingContext::IsInitialized()) {
11-
lua_pushnil(L);
12-
lua_pushstring(L, "Scheduler had been initialized.");
13-
return 2;
14-
}
9+
// Lua method `sched:cpu_topology()`.
// Expects a cgemma.scheduler userdata at stack index 1 (scheduler::check uses
// luaL_checkudata, which raises a Lua error on a type mismatch) and pushes the
// scheduler's topology description as the single return value.
int cpu_topology(lua_State* L) {
  const cgemma::scheduler* self = cgemma::scheduler::check(L, 1);
  const char* description = self->cpu_topology();
  lua_pushstring(L, description);
  return 1;
}
14+
15+
// `__gc` metamethod for cgemma.scheduler userdata.
// Only the destructor is invoked: scheduler::create placement-constructs the
// object inside memory obtained from lua_newuserdata, so the storage itself is
// not released here. Returns no values to Lua.
int destroy(lua_State* L) {
  cgemma::scheduler* self = cgemma::scheduler::check(L, 1);
  self->~scheduler();
  return 0;
}
19+
20+
}
21+
22+
namespace cgemma {
23+
24+
// Default scheduler: args_ is not listed and is therefore default-initialized.
// Members are initialized in declaration order (args_, ctx_, env_), so args_
// is ready before ctx_ consumes it and ctx_ is ready before env_ consumes it.
scheduler::scheduler() : ctx_(args_), env_(ctx_) {}
29+
30+
// Builds a scheduler from command-line style options.
//
// argc: number of entries in argv (the parameter was previously spelled
//       `args`, which disagreed with the declaration in scheduler.hpp).
// argv: option strings; scheduler::create fills it with a program-name
//       placeholder followed by "--option", "value" pairs.
//
// Members are initialized in declaration order (args_, ctx_, env_), so each
// one is fully constructed before the next member is built from it.
scheduler::scheduler(int argc, char* argv[])
  : args_(argc, argv)
  , ctx_(args_)
  , env_(ctx_) {
  // nop
}
36+
37+
// Creates the "cgemma.scheduler" metatable and fills it with:
//   - the __gc finalizer,
//   - a _NAME string holding the type name,
//   - an __index table containing the instance methods.
// Every value pushed after luaL_newmetatable is consumed by a lua_setfield,
// so the metatable itself is still on top of the stack on return.
void scheduler::declare(lua_State* L) {
  constexpr luaL_Reg meta_entries[] = {
    {"__gc", destroy},
    {nullptr, nullptr}
  };
  // `::cpu_topology` is qualified so the file-local Lua binding is selected
  // rather than the member function scheduler::cpu_topology.
  constexpr luaL_Reg method_entries[] = {
    {"cpu_topology", ::cpu_topology},
    {nullptr, nullptr}
  };
  const auto name_len = sizeof(name) - 1;  // length without the trailing '\0'

  luaL_newmetatable(L, name);
  luaL_register(L, nullptr, meta_entries);

  lua_pushlstring(L, name, name_len);
  lua_setfield(L, -2, "_NAME");

  lua_newtable(L);
  luaL_register(L, nullptr, method_entries);
  lua_setfield(L, -2, "__index");
}
54+
55+
// Converts the Lua value at `index` to a scheduler pointer via
// utils::userdata, looking it up under the "cgemma.scheduler" type name.
// NOTE(review): utils::userdata is defined elsewhere; presumably it yields
// nullptr instead of raising when the value is not a scheduler (instance::create
// passes the result on unchecked and the instance constructor treats a null
// scheduler as "use a default one") -- confirm against utils.
scheduler* scheduler::to(lua_State* L, int index) {
  auto raw = utils::userdata(L, index, name);
  return static_cast<scheduler*>(raw);
}
58+
59+
// Returns the scheduler stored in the userdata at `index`.
// Delegates to luaL_checkudata, which raises a Lua error when the value is
// not a userdata carrying the "cgemma.scheduler" metatable.
scheduler* scheduler::check(lua_State* L, int index) {
  void* raw = luaL_checkudata(L, index, name);
  return static_cast<scheduler*>(raw);
}
62+
63+
int scheduler::create(lua_State* L) {
1564
constexpr const char* available_options[] = {
1665
"--num_threads", "--pin", "--bind",
1766
"--skip_packages", "--max_packages",
@@ -21,52 +70,32 @@ int config(lua_State* L) {
2170
constexpr const int n = sizeof(available_options) / sizeof(available_options[0]);
2271
int argc = 1;
2372
char* argv[n * 2 + 1] = {const_cast<char*>("lua-cgemma")};
24-
luaL_checktype(L, 1, LUA_TTABLE);
25-
for (auto opt: available_options) {
26-
auto k = opt + 2;
27-
lua_getfield(L, 1, k);
28-
auto v = lua_tostring(L, -1);
29-
if (v) {
30-
argv[argc++] = const_cast<char*>(opt);
31-
argv[argc++] = const_cast<char*>(v);
73+
auto nargs = lua_gettop(L);
74+
if (nargs > 0) {
75+
luaL_checktype(L, 1, LUA_TTABLE);
76+
for (auto opt: available_options) {
77+
auto k = opt + 2;
78+
lua_getfield(L, 1, k);
79+
auto v = lua_tostring(L, -1);
80+
if (v) {
81+
argv[argc++] = const_cast<char*>(opt);
82+
argv[argc++] = const_cast<char*>(v);
83+
}
84+
lua_pop(L, 1);
3285
}
33-
lua_pop(L, 1);
3486
}
35-
gcpp::ThreadingContext::SetArgs(gcpp::ThreadingArgs(argc, argv));
36-
if (gcpp::ThreadingContext::IsInitialized()) {
37-
lua_pushnil(L);
38-
lua_pushstring(L, "Scheduler had been initialized.");
39-
return 2;
40-
}
41-
lua_pushboolean(L, 1);
42-
return 1;
43-
}
44-
45-
int cpu_topology(lua_State* L) {
87+
auto ud = lua_newuserdata(L, sizeof(scheduler));
4688
try {
47-
lua_pushstring(L, gcpp::ThreadingContext::Get().topology.TopologyString());
89+
new(ud) scheduler(argc, argv);
90+
luaL_getmetatable(L, name);
91+
lua_setmetatable(L, -2);
4892
return 1;
4993
} catch (const std::exception& e) {
94+
lua_pop(L, 1);
5095
lua_pushnil(L);
5196
lua_pushstring(L, e.what());
5297
return 2;
5398
}
5499
}
55100

56101
}
57-
58-
namespace cgemma { namespace scheduler {
59-
60-
void declare(lua_State* L) {
61-
constexpr const luaL_Reg entries[] = {
62-
{"config", config},
63-
{"cpu_topology", cpu_topology},
64-
{nullptr, nullptr}
65-
};
66-
lua_newtable(L);
67-
luaL_register(L, nullptr, entries);
68-
lua_pushliteral(L, "cgemma.scheduler");
69-
lua_setfield(L, -2, "_NAME");
70-
}
71-
72-
} }

src/scheduler.hpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,31 @@
22
#define CGEMMA_SCHEDULER_HPP
33

44
#include <lua.hpp>
5+
#include <util/threading_context.h>
6+
#include <ops/matmul.h>
57

6-
namespace cgemma { namespace scheduler {
8+
namespace cgemma {
79

8-
void declare(lua_State* L);
10+
class scheduler {
11+
public:
12+
scheduler();
13+
scheduler(int argc, char* argv[]);
914

10-
} }
15+
const char* cpu_topology() const { return ctx_.topology.TopologyString(); }
16+
gcpp::ThreadingContext& threading_ctx() { return ctx_; }
17+
gcpp::MatMulEnv& matmul_env() { return env_; }
18+
19+
static void declare(lua_State* L);
20+
static scheduler* to(lua_State* L, int index);
21+
static scheduler* check(lua_State* L, int index);
22+
static int create(lua_State* L);
23+
24+
private:
25+
gcpp::ThreadingArgs args_;
26+
gcpp::ThreadingContext ctx_;
27+
gcpp::MatMulEnv env_;
28+
};
29+
30+
}
1131

1232
#endif // CGEMMA_SCHEDULER_HPP

0 commit comments

Comments
 (0)