@@ -182,6 +182,57 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
182182 << " command line flag.\n " ;
183183}
184184
185+
186+ std::vector<int > tokenize (
187+ const std::string& prompt_string,
188+ const sentencepiece::SentencePieceProcessor* tokenizer) {
189+ std::string formatted = " <start_of_turn>user\n " + prompt_string +
190+ " <end_of_turn>\n <start_of_turn>model\n " ;
191+ std::vector<int > tokens;
192+ HWY_ASSERT (tokenizer->Encode (formatted, &tokens).ok ());
193+ tokens.insert (tokens.begin (), 2 ); // BOS token
194+ return tokens;
195+ }
196+
197+ int GemmaWrapper::completionPrompt (std::string& prompt) {
198+ size_t pos = 0 ; // KV Cache position
199+ size_t num_threads = static_cast <size_t >(std::clamp (
200+ static_cast <int >(std::thread::hardware_concurrency ()) - 2 , 1 , 18 ));
201+ hwy::ThreadPool pool (num_threads);
202+ // Initialize random number generator
203+ std::mt19937 gen;
204+ std::random_device rd;
205+ gen.seed (rd ());
206+
207+ // Tokenize instruction
208+ std::vector<int > tokens =
209+ tokenize (prompt, this ->m_model ->Tokenizer ());
210+ size_t ntokens = tokens.size ();
211+
212+ // This callback function gets invoked everytime a token is generated
213+ auto stream_token = [&pos, &gen, &ntokens, tokenizer = this ->m_model ->Tokenizer ()](
214+ int token, float ) {
215+ ++pos;
216+ if (pos < ntokens) {
217+ // print feedback
218+ } else if (token != gcpp::EOS_ID) {
219+ std::string token_text;
220+ HWY_ASSERT (tokenizer->Decode (std::vector<int >{token}, &token_text).ok ());
221+ std::cout << token_text << std::flush;
222+ }
223+ return true ;
224+ };
225+
226+ GenerateGemma (*this ->m_model ,
227+ {.max_tokens = 2048 ,
228+ .max_generated_tokens = 1024 ,
229+ .temperature = 1.0 ,
230+ .verbosity = 0 },
231+ tokens, /* KV cache position = */ 0 , this ->m_kvcache , pool,
232+ stream_token, gen);
233+ std::cout << std::endl;
234+ }
235+
185236void GemmaWrapper::loadModel (const std::vector<std::string> &args) {
186237 int argc = args.size () + 1 ; // +1 for the program name
187238 std::vector<char *> argv_vec;
@@ -269,5 +320,5 @@ PYBIND11_MODULE(pygemma, m) {
269320 };
270321 self.loadModel (args); // Assuming GemmaWrapper::loadModel accepts std::vector<std::string>
271322 }, py::arg (" tokenizer" ), py::arg (" compressed_weights" ), py::arg (" model" ))
272- .def (" completion" , &GemmaWrapper::completionPrompt);
323+ .def (" completion" , &GemmaWrapper::completionPrompt, " Function that completes given prompt. " );
273324}
0 commit comments