
Commit 017a7e3

ggerganov and shaobo.xie authored and committed
common : refactor common_sampler + grammar logic changes (#17937)
* common : refactor common_sampler + grammar logic changes
* tests : increase max_tokens to get needed response
* batched : fix uninitialized samplers
1 parent 2525867 commit 017a7e3

27 files changed: 370 additions & 291 deletions
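
At a high level, common_init_from_params no longer returns a common_init_result value with public model/context smart-pointer members. It now returns a common_init_result_ptr (a std::unique_ptr<common_init_result>) whose pimpl owns the model, the context, the loaded LoRA adapters, and one pre-initialized common_sampler per sequence. A minimal sketch of how a caller migrates, assuming only the interfaces visible in the diffs below (the function name init_example is hypothetical; error handling is trimmed):

#include "common.h"

// sketch: initialize everything through the new opaque handle
static bool init_example(common_params & params) {
    common_init_result_ptr llama_init = common_init_from_params(params);

    llama_model   * model = llama_init->model();   // previously: llama_init.model.get()
    llama_context * lctx  = llama_init->context(); // previously: llama_init.context.get()
    if (model == nullptr || lctx == nullptr) {
        return false; // failures are already logged by common_init_from_params
    }

    common_sampler * smpl = llama_init->sampler(0); // sampler for sequence 0
    (void) smpl;

    return true;
}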


common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -1415,7 +1415,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.top_k = value;
             params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
         }
-    ).set_sparam());
+    ).set_sparam().set_env("LLAMA_ARG_TOP_K"));
     add_opt(common_arg(
         {"--top-p"}, "N",
         string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),

common/common.cpp

Lines changed: 125 additions & 65 deletions
@@ -1014,31 +1014,40 @@ bool tty_can_use_colors() {
 // Model utils
 //

-static inline void common_init_sampler_from_model(
+// TODO: move to common/sampling
+static void common_init_sampler_from_model(
         const llama_model * model,
         common_params_sampling & sparams) {

     const uint64_t config = sparams.user_sampling_config;

     auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }

         char buf[64] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             int32_t v = strtol(buf, &end, 10);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };

     auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }

         char buf[128] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             float v = strtof(buf, &end);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };
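
The two lambdas above implement a guarded-override pattern: a sampling default embedded in the model's metadata is applied only when the corresponding bit in user_sampling_config is clear, i.e. only when the user did not set that parameter explicitly. A self-contained sketch of the same idea outside llama.cpp (the enum, struct, and key names here are hypothetical):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>

// hypothetical bitmask: one bit per parameter the user set on the command line
enum user_config_bits : uint64_t {
    CONFIG_TOP_K = 1 << 0,
    CONFIG_TEMP  = 1 << 1,
};

struct sampling_params {
    int32_t  top_k = 40;
    float    temp  = 0.8f;
    uint64_t user_sampling_config = 0; // bits set by the argument parser
};

// apply a metadata-provided default only if the user did not override it
static void apply_meta_defaults(const std::map<std::string, std::string> & meta, sampling_params & sp) {
    auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_bit) {
        if (sp.user_sampling_config & user_bit) {
            return; // the user set this explicitly - keep their value
        }
        const auto it = meta.find(key);
        if (it == meta.end()) {
            return; // the model provides no default for this key
        }
        char * end = nullptr;
        const int32_t v = strtol(it->second.c_str(), &end, 10);
        if (end && end != it->second.c_str()) {
            dst = v; // parsed successfully
        }
    };

    get_int32("sampling.top_k", sp.top_k, CONFIG_TOP_K);
}

int main() {
    sampling_params sp;
    sp.user_sampling_config |= CONFIG_TEMP; // pretend the user passed --temp

    apply_meta_defaults({{"sampling.top_k", "20"}}, sp);
    printf("top_k = %d (model default), temp = %.1f (user)\n", sp.top_k, sp.temp);
    return 0;
}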

@@ -1066,31 +1075,122 @@ static inline void common_init_sampler_from_model(
     get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
 }

-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
-    auto mparams = common_model_params_to_llama(params);
+struct common_init_result::impl {
+    impl() = default;
+    ~impl() = default;
+
+    llama_model_ptr model;
+    llama_context_ptr context;
+
+    std::vector<llama_adapter_lora_ptr> lora;
+
+    std::vector<common_sampler_ptr> samplers;
+};
+
+common_init_result::common_init_result(common_params & params) :
+    pimpl(new impl{}) {
+    const auto mparams = common_model_params_to_llama(params);

     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
-        return iparams;
+        return;
     }

-    common_init_sampler_from_model(model, params.sampling);
+    pimpl->model.reset(model);

     const llama_vocab * vocab = llama_model_get_vocab(model);

+    // updates params.sampling
+    // TODO: fix naming
+    common_init_sampler_from_model(model, params.sampling);
+
     auto cparams = common_context_params_to_llama(params);

+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
+    }
+
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
+    //if (params.sampling.penalty_last_n == -1) {
+    //    LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    //if (params.sampling.dry_penalty_last_n == -1) {
+    //    LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //    params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    //}
+
+    pimpl->samplers.resize(cparams.n_seq_max);
+
+    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return;
+    }
+
+    pimpl->context.reset(lctx);
+}
+
+llama_model * common_init_result::model() {
+    return pimpl->model.get();
+}
+
+llama_context * common_init_result::context() {
+    return pimpl->context.get();
+}
+
+common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+    return pimpl->samplers[seq_id].get();
+}
+
+std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
+    return pimpl->lora;
+}
+
+void common_init_result::free_context() {
+    pimpl->context.reset();
+}
+
+common_init_result_ptr common_init_from_params(common_params & params) {
+    common_init_result_ptr res(new common_init_result(params));
+
+    llama_model * model = res->model();
+    if (model == NULL) {
+        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return res;
+    }
+
+    llama_context * lctx = res->context();
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                 __func__, params.model.path.c_str());
-        llama_model_free(model);
-        return iparams;
+        return res;
     }

+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
         LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
@@ -1102,10 +1202,7 @@ struct common_init_result common_init_from_params(common_params & params) {

         const auto cvec = common_control_vector_load(params.control_vectors);
         if (cvec.n_embd == -1) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }

         int err = llama_apply_adapter_cvec(
@@ -1116,10 +1213,7 @@ struct common_init_result common_init_from_params(common_params & params) {
                 params.control_vector_layer_start,
                 params.control_vector_layer_end);
         if (err) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }

@@ -1143,10 +1237,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }

         if (!ok) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }

@@ -1156,9 +1247,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
         if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_model_free(model);
-            return iparams;
+            return res;
         }

         char buf[1024];
@@ -1167,43 +1256,13 @@ struct common_init_result common_init_from_params(common_params & params) {
         la.task_name = buf;
         llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
         la.prompt_prefix = buf;
-        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+        res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
     }

     if (!params.lora_init_without_apply) {
         common_set_adapter_lora(lctx, params.lora_adapters);
     }

-    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sampling.ignore_eos = false;
-    }
-
-    // initialize once
-    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-        if (llama_vocab_is_eog(vocab, i)) {
-            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
-        }
-    }
-
-    if (params.sampling.ignore_eos) {
-        // add EOG biases to the active set of logit biases
-        params.sampling.logit_bias.insert(
-                params.sampling.logit_bias.end(),
-                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
-    }
-
-    if (params.sampling.penalty_last_n == -1) {
-        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.penalty_last_n = llama_n_ctx(lctx);
-    }
-
-    if (params.sampling.dry_penalty_last_n == -1) {
-        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
-    }
-
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);

@@ -1282,12 +1341,11 @@ struct common_init_result common_init_from_params(common_params & params) {
     }
 #endif

-    iparams.model.reset(model);
-    iparams.context.reset(lctx);
-
-    return iparams;
+    return res;
 }

+common_init_result::~common_init_result() = default;
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1296,7 +1354,9 @@ std::string get_model_endpoint() {
     std::string model_endpoint = "https://huggingface.co/";
     if (endpoint_env) {
         model_endpoint = endpoint_env;
-        if (model_endpoint.back() != '/') model_endpoint += '/';
+        if (model_endpoint.back() != '/') {
+            model_endpoint += '/';
+        }
     }
     return model_endpoint;
 }
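
Because the constructor now creates cparams.n_seq_max samplers up front, every sequence has a valid sampler before decoding starts; this is what the "batched : fix uninitialized samplers" bullet in the commit message refers to. A sketch of a per-sequence sampling step through the new accessor; common_sampler_sample and common_sampler_accept are the existing helpers declared in common/sampling.h (whose grammar-related behavior this commit also touches), and n_seq/i_batch are illustrative caller-side bookkeeping:

// sketch: sample one token for each parallel sequence after a decode
for (int32_t s = 0; s < n_seq; ++s) {
    common_sampler * smpl = llama_init->sampler(s); // pre-initialized per sequence

    const llama_token id = common_sampler_sample(smpl, lctx, i_batch[s]);
    common_sampler_accept(smpl, id, true); // update sampler state, incl. grammar
}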

common/common.h

Lines changed: 23 additions & 6 deletions
@@ -195,7 +195,6 @@ struct common_params_sampling {

     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

-
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
@@ -216,6 +215,10 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

+    bool has_logit_bias() const {
+        return !logit_bias.empty();
+    }
+
     // print the parameters into a string
     std::string print() const;
 };
@@ -673,15 +676,29 @@ bool tty_can_use_colors();
 // Model utils
 //

-// note: defines object's lifetime
+struct common_sampler;
+
+// note: defines the model, context, samplers, ets. lifetimes
 struct common_init_result {
-    llama_model_ptr model;
-    llama_context_ptr context;
+    common_init_result(common_params & params);
+    ~common_init_result();

-    std::vector<llama_adapter_lora_ptr> lora;
+    llama_model * model();
+    llama_context * context();
+    common_sampler * sampler(llama_seq_id seq_id);
+
+    std::vector<llama_adapter_lora_ptr> & lora();
+
+    void free_context();
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };

-struct common_init_result common_init_from_params(common_params & params);
+using common_init_result_ptr = std::unique_ptr<common_init_result>;
+
+common_init_result_ptr common_init_from_params(common_params & params);

 struct llama_model_params   common_model_params_to_llama  (      common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
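
One detail worth noting about the pimpl declaration above: common_init_result::impl is an incomplete type in this header, so the destructor is only declared here and then defined as "= default" in common.cpp, after impl is fully defined, because std::unique_ptr<impl> needs the complete type to instantiate its deleter. A minimal generic sketch of the same pattern (widget and its members are hypothetical):

// widget.h - the public header; impl stays an incomplete type here
#include <memory>

class widget {
public:
    widget();
    ~widget(); // declared only: defaulting it here would require a complete impl

    int value() const;

private:
    struct impl;
    std::unique_ptr<impl> pimpl;
};

// widget.cpp - impl is complete here, so the defaulted destructor compiles
struct widget::impl {
    int value = 42;
};

widget::widget() : pimpl(new impl{}) {}
widget::~widget() = default;

int widget::value() const {
    return pimpl->value;
}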
