@@ -1014,31 +1014,40 @@ bool tty_can_use_colors() {
 // Model utils
 //
 
-static inline void common_init_sampler_from_model(
+// TODO: move to common/sampling
+static void common_init_sampler_from_model(
         const llama_model * model,
         common_params_sampling & sparams) {
 
     const uint64_t config = sparams.user_sampling_config;
 
     auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }
 
         char buf[64] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             int32_t v = strtol(buf, &end, 10);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };
 
     auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
-        if (config & user_config) return;
+        if (config & user_config) {
+            return;
+        }
 
         char buf[128] = {0};
         if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
             char * end = nullptr;
             float v = strtof(buf, &end);
-            if (end && end != buf) dst = v;
+            if (end && end != buf) {
+                dst = v;
+            }
         }
     };
 
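The gating above skips a GGUF metadata key whenever the corresponding bit is set in sparams.user_sampling_config, i.e. the user supplied that parameter explicitly and the model's embedded default must not override it. A standalone sketch of that idea, using hypothetical bit assignments rather than llama.cpp's actual enum values:

    #include <cstdint>
    #include <cstdio>

    enum : uint64_t {
        CONFIG_TEMP  = 1ull << 0, // hypothetical bit assignment
        CONFIG_TOP_K = 1ull << 1, // hypothetical bit assignment
    };

    int main() {
        const uint64_t user_config = CONFIG_TEMP; // user passed --temp on the CLI

        float temp = 0.8f;                // user's explicit value, must survive
        const float model_default = 0.6f; // value read from model metadata

        if (!(user_config & CONFIG_TEMP)) {
            temp = model_default; // applied only when the user did not set it
        }

        printf("temp = %.1f\n", temp); // prints: temp = 0.8
        return 0;
    }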
@@ -1066,31 +1075,122 @@ static inline void common_init_sampler_from_model(
     get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
 }
 
-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
-    auto mparams = common_model_params_to_llama(params);
+struct common_init_result::impl {
+    impl() = default;
+    ~impl() = default;
+
+    llama_model_ptr model;
+    llama_context_ptr context;
+
+    std::vector<llama_adapter_lora_ptr> lora;
+
+    std::vector<common_sampler_ptr> samplers;
+};
+
+common_init_result::common_init_result(common_params & params) :
+    pimpl(new impl{}) {
+    const auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
-        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
-        return iparams;
+        return;
     }
 
-    common_init_sampler_from_model(model, params.sampling);
+    pimpl->model.reset(model);
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
+    // updates params.sampling
+    // TODO: fix naming
+    common_init_sampler_from_model(model, params.sampling);
+
     auto cparams = common_context_params_to_llama(params);
 
+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+        }
+    }
+
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
+    // if (params.sampling.penalty_last_n == -1) {
+    //     LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //     params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    // }
+
+    // if (params.sampling.dry_penalty_last_n == -1) {
+    //     LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+    //     params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    // }
+
+    pimpl->samplers.resize(cparams.n_seq_max);
+
+    for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+        pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return;
+    }
+
+    pimpl->context.reset(lctx);
+}
+
+llama_model * common_init_result::model() {
+    return pimpl->model.get();
+}
+
+llama_context * common_init_result::context() {
+    return pimpl->context.get();
+}
+
+common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+    return pimpl->samplers[seq_id].get();
+}
+
+std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
+    return pimpl->lora;
+}
+
+void common_init_result::free_context() {
+    pimpl->context.reset();
+}
+
+common_init_result_ptr common_init_from_params(common_params & params) {
+    common_init_result_ptr res(new common_init_result(params));
+
+    llama_model * model = res->model();
+    if (model == NULL) {
+        LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+                __func__, params.model.path.c_str());
+        return res;
+    }
+
+    llama_context * lctx = res->context();
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                 __func__, params.model.path.c_str());
-        llama_model_free(model);
-        return iparams;
+        return res;
     }
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
         LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
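A minimal sketch of how a caller would consume the new interface, assuming common_init_result_ptr is the smart-pointer alias declared in common/common.h (presumably std::unique_ptr<common_init_result>); the parameter values are hypothetical:

    common_params params;
    params.model.path = "model.gguf"; // hypothetical model path

    common_init_result_ptr init = common_init_from_params(params);

    llama_model   * model = init->model();   // NULL if loading failed
    llama_context * lctx  = init->context(); // NULL if context creation failed
    if (model == NULL || lctx == NULL) {
        return 1; // the pimpl's smart pointers release any partially created state
    }

    // one sampler per sequence is pre-initialized, up to cparams.n_seq_max
    common_sampler * smpl = init->sampler(0);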
@@ -1102,10 +1202,7 @@ struct common_init_result common_init_from_params(common_params & params) {
 
         const auto cvec = common_control_vector_load(params.control_vectors);
         if (cvec.n_embd == -1) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
 
         int err = llama_apply_adapter_cvec(
@@ -1116,10 +1213,7 @@ struct common_init_result common_init_from_params(common_params & params) {
                 params.control_vector_layer_start,
                 params.control_vector_layer_end);
         if (err) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }
 
@@ -1143,10 +1237,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (!ok) {
-            llama_free(lctx);
-            llama_model_free(model);
-
-            return iparams;
+            return res;
         }
     }
 
@@ -1156,9 +1247,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
         if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_model_free(model);
-            return iparams;
+            return res;
         }
 
         char buf[1024];
@@ -1167,43 +1256,13 @@ struct common_init_result common_init_from_params(common_params & params) {
         la.task_name = buf;
         llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
         la.prompt_prefix = buf;
-        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+        res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
 
     if (!params.lora_init_without_apply) {
         common_set_adapter_lora(lctx, params.lora_adapters);
     }
 
-    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sampling.ignore_eos = false;
-    }
-
-    // initialize once
-    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-        if (llama_vocab_is_eog(vocab, i)) {
-            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
-        }
-    }
-
-    if (params.sampling.ignore_eos) {
-        // add EOG biases to the active set of logit biases
-        params.sampling.logit_bias.insert(
-                params.sampling.logit_bias.end(),
-                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
-    }
-
-    if (params.sampling.penalty_last_n == -1) {
-        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.penalty_last_n = llama_n_ctx(lctx);
-    }
-
-    if (params.sampling.dry_penalty_last_n == -1) {
-        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
-        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
-    }
-
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
@@ -1282,12 +1341,11 @@ struct common_init_result common_init_from_params(common_params & params) {
     }
 #endif
 
-    iparams.model.reset(model);
-    iparams.context.reset(lctx);
-
-    return iparams;
+    return res;
 }
 
+common_init_result::~common_init_result() = default;
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1296,7 +1354,9 @@ std::string get_model_endpoint() {
     std::string model_endpoint = "https://huggingface.co/";
     if (endpoint_env) {
         model_endpoint = endpoint_env;
-        if (model_endpoint.back() != '/') model_endpoint += '/';
+        if (model_endpoint.back() != '/') {
+            model_endpoint += '/';
+        }
     }
     return model_endpoint;
 }
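The matching declaration in common/common.h is not shown in this diff; a plausible shape inferred from the definitions above (member order and the common_init_result_ptr alias are assumptions):

    struct common_init_result {
        common_init_result(common_params & params);
        ~common_init_result(); // defined out-of-line in common.cpp

        llama_model    * model();
        llama_context  * context();
        common_sampler * sampler(llama_seq_id seq_id);

        std::vector<llama_adapter_lora_ptr> & lora();

        void free_context();

      private:
        struct impl; // forward-declared; hides llama_model_ptr/llama_context_ptr
        std::unique_ptr<impl> pimpl;
    };

    using common_init_result_ptr = std::unique_ptr<common_init_result>;

The out-of-line "~common_init_result() = default;" is load-bearing for the pimpl idiom: std::unique_ptr<impl> requires impl to be a complete type at the point where the destructor is instantiated, which is why the diff defines the destructor in common.cpp after struct impl rather than in the header.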