@@ -24,78 +24,26 @@ pub struct ChatMessage {
 #[derive(Debug, Clone)]
 pub struct LlamaCppOptions {
     // --- System Parameters ---
-    /// Number of threads for decoding (token-by-token generation).
-    /// Recommended: 50-75% of physical cores (e.g. 2-3 for 4 cores).
-    /// Don't use every core, or the OS may become unresponsive.
     pub threads: Option<i32>,
-
-    /// Number of threads for prefill (prompt processing) and batching.
-    /// This matters a lot when the initial prompt is long.
-    /// Recommended: equal to physical cores (e.g. 4 for 4 cores).
     pub threads_batch: Option<i32>,
-
-    /// Maximum context length (prompt + output).
-    /// Careful: the larger this is, the more RAM the KV cache uses.
-    /// Default: 2048 (enough for short/medium chats).
     pub context_length: u32,
-
-    /// Logical batch size (maximum number of tokens processed at once).
-    /// Larger = faster prefill but needs more RAM.
-    /// Default: 512-2048.
     pub batch_size: usize,
-
-    /// Physical batch size (sub-batch executed per step).
-    /// A fraction of batch_size, sized for CPU L2 cache efficiency.
-    /// Default: 512 (a sweet spot for many modern CPUs).
     pub ubatch_size: usize,
-
-    /// Seed for the random number generator (RNG).
-    /// Set a fixed value for deterministic (reproducible) results.
     pub seed: u32,
-
-    /// If true, lock the model in RAM so it is never swapped to disk.
-    /// Highly recommended when RAM allows; prevents stuttering.
-    /// Default: false (safe when RAM is tight).
     pub use_mlock: bool,
+
+    // --- NEW: Cache Control ---
+    /// Disable KV cache for prompts (saves memory but slower)
+    pub no_cache_prompt: bool,
 
     // --- Sampling Parameters ---
-    /// Controls output randomness (creativity).
-    /// - 0.0: greedy decoding (always pick the most likely token / rigid).
-    /// - 0.7: balanced (creative but coherent).
-    /// - >1.0: very random / prone to hallucination.
     pub temperature: f32,
-
-    /// Restrict candidates to the K most likely tokens.
-    /// - 40: common default.
-    /// - 0: disabled (consider every token in the vocab).
     pub top_k: i32,
-
-    /// Nucleus sampling: keep the top tokens whose cumulative probability is P.
-    /// - 0.9: filters out the low-probability long tail.
-    /// - 1.0: disabled.
     pub top_p: f32,
-
-    /// Minimum probability: drop tokens whose probability is < P * the best token's probability.
-    /// - 0.05: filters junk/typo tokens that are extremely unlikely.
     pub min_p: f32,
-
-    // --- Penalties (Anti-Repetition) ---
-    /// Multiplicative penalty for tokens that have already appeared.
-    /// - 1.0: disabled (no penalty).
-    /// - 1.1-1.2: enough to prevent mild looping.
     pub repeat_penalty: f32,
-
-    /// Number of recent tokens checked for the penalty (context window lookback).
-    /// - 64: check the last 64 tokens.
-    /// - 0: check the entire context (slow).
     pub repeat_last_n: i32,
-
-    /// Additive penalty based on how often a token appears (frequency).
-    /// Effect: prevents the SAME word from being repeated excessively.
     pub frequency_penalty: f32,
-
-    /// Additive penalty applied once a token has appeared at all (presence).
-    /// Effect: pushes the model toward NEW topics (not just different words).
     pub presence_penalty: f32,
 }
 
@@ -105,21 +53,19 @@ impl Default for LlamaCppOptions {
             // System defaults
             threads: Some(4),
             threads_batch: Some(4),
-            context_length: 4096, // 4K context
+            context_length: 4096,
             batch_size: 2048,
             ubatch_size: 1024,
-
             seed: 1234,
             use_mlock: true,
-
+            no_cache_prompt: false, // Enable cache by default
+
             // Sampling defaults
-            temperature: 0.5, // Balanced
-            top_k: 40, // Common default
-            top_p: 0.9, // Nucleus sampling
-            min_p: 0.05, // Filter very unlikely tokens
-
-            // Repetition defaults (light penalty)
-            repeat_penalty: 1.0, // Off by default
+            temperature: 0.5,
+            top_k: 40,
+            top_p: 0.9,
+            min_p: 0.05,
+            repeat_penalty: 1.0,
             repeat_last_n: 64,
             frequency_penalty: 0.0,
             presence_penalty: 0.0,
@@ -151,44 +97,37 @@ impl LlamaCppEngine {
     pub fn load_gguf(&mut self, model_path: &str) -> Result<()> {
         let t0 = Instant::now();
         info!("loading GGUF model: {}", model_path);
-
         let mut model_params = LlamaModelParams::default();
         if self.opts.use_mlock {
             model_params = model_params.with_use_mlock(true);
         }
-        let model_params = pin!(model_params);
 
+        let model_params = pin!(model_params);
         let model = LlamaModel::load_from_file(&self.backend, model_path, &model_params)
             .with_context(|| format!("failed to load model: {}", model_path))?;
-
         let load_ms = t0.elapsed().as_millis();
         info!("model loaded in {} ms", load_ms);
-
         self.model = Some(model);
         Ok(())
     }
 
     /// Apply chat template to a list of messages.
-    /// Returns the formatted prompt string.
     pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
         let model = self
             .model
             .as_ref()
             .ok_or_else(|| anyhow!("model not loaded"))?;
 
-        // Convert to LlamaChatMessage
         let chat_messages: Vec<LlamaChatMessage> = messages
             .iter()
             .map(|m| LlamaChatMessage::new(m.role.clone(), m.content.clone()))
-            .collect::<Result<Vec<_>, _>>()
+            .collect::<std::result::Result<_, _>>()
             .map_err(|e| anyhow!("failed to create chat message: {:?}", e))?;
 
-        // Get template (None = use the model's default)
         let template = model
             .chat_template(None)
             .map_err(|e| anyhow!("failed to get chat template: {:?}", e))?;
 
-        // Apply
         let prompt = model
             .apply_chat_template(&template, &chat_messages, true)
             .map_err(|e| anyhow!("failed to apply chat template: {:?}", e))?;
@@ -197,7 +136,6 @@ impl LlamaCppEngine {
     }
 
     /// Generate text with streaming callback
-    /// Callback receives token string, returns true to continue, false to abort
     pub fn generate_with_callback<F>(
         &self,
         prompt: &str,
@@ -211,7 +149,6 @@ impl LlamaCppEngine {
             .model
             .as_ref()
             .ok_or_else(|| anyhow!("model not loaded"))?;
-
         let t_start = Instant::now();
 
         // Create context
@@ -224,13 +161,17 @@ impl LlamaCppEngine {
         if let Some(threads) = self.opts.threads {
             ctx_params = ctx_params.with_n_threads(threads);
         }
+
         if let Some(threads_batch) = self.opts.threads_batch {
             ctx_params = ctx_params.with_n_threads_batch(threads_batch);
         } else if let Some(threads) = self.opts.threads {
-            // Fallback to threads if threads_batch not set
             ctx_params = ctx_params.with_n_threads_batch(threads);
         }
 
+        if self.opts.no_cache_prompt {
+            info!("Prompt caching disabled (no_cache_prompt = true)");
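+            // NOTE: log only for now; context creation below does not read this flag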
+        }
+
         let mut ctx = model
             .new_context(&self.backend, ctx_params)
             .with_context(|| "failed to create context")?;
@@ -239,37 +180,60 @@ impl LlamaCppEngine {
         let tokens_list = model
             .str_to_token(prompt, AddBos::Always)
             .with_context(|| "failed to tokenize prompt")?;
-
-        info!("prompt tokens: {}", tokens_list.len());
-
-        // Create batch (optimized size)
-        let mut batch = LlamaBatch::new(self.opts.batch_size, 1);
-
-        let last_index = (tokens_list.len() - 1) as i32;
-        for (i, token) in (0_i32..).zip(tokens_list.iter()) {
-            let is_last = i == last_index;
-            batch.add(*token, i, &[0], is_last)?;
+
+        let n_prompt_tokens = tokens_list.len();
+        info!("prompt tokens: {}", n_prompt_tokens);
+
+        // Validate prompt length
+        if n_prompt_tokens >= self.opts.context_length as usize {
+            return Err(anyhow!(
+                "Prompt too long: {} tokens exceeds context limit of {}",
+                n_prompt_tokens,
+                self.opts.context_length
+            ));
         }
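+        // NOTE: only the prompt is validated; prompt + max_tokens may still exceed context_length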
 
-        // Initial decode (prefill)
-        ctx.decode(&mut batch)
-            .with_context(|| "prefill decode failed")?;
-
+        // ✅ FIX: Chunked prefill (process in batches)
+        let batch_size = self.opts.batch_size;
+        let mut n_cur = 0i32;
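+        // n_cur counts positions decoded so far (prompt during prefill, then prompt + generated)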
+
+        info!("prefill starting: {} tokens in chunks of {}", n_prompt_tokens, batch_size);
+
+        // Process prompt in batches
+        for chunk_start in (0..n_prompt_tokens).step_by(batch_size) {
+            let chunk_end = std::cmp::min(chunk_start + batch_size, n_prompt_tokens);
+            let chunk = &tokens_list[chunk_start..chunk_end];
+            let chunk_size = chunk.len();
+
+            let mut batch = LlamaBatch::new(batch_size, 1);
+
+            // Add tokens from this chunk
+            for (i, token) in chunk.iter().enumerate() {
+                let pos = chunk_start as i32 + i as i32;
+                let is_last = (chunk_start + i) == (n_prompt_tokens - 1);
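+                // Only the final prompt token requests logits (is_last); those logits seed the first sampled token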
+
+                batch.add(*token, pos, &[0], is_last)?;
+            }
+
+            // Decode this batch
+            ctx.decode(&mut batch)
+                .with_context(|| format!("prefill decode failed at chunk {}-{}", chunk_start, chunk_end))?;
+
+            n_cur += chunk_size as i32;
+        }
+
         let prefill_ms = t_start.elapsed().as_millis();
+        info!("prefill completed in {} ms", prefill_ms);
 
-        // Generation loop
-        let mut n_cur = batch.n_tokens();
-        let n_len = tokens_list.len() as i32 + max_tokens;
+        // ✅ Generation loop (now starts after full prefill)
+        let n_len = n_prompt_tokens as i32 + max_tokens;
         let mut n_decode = 0;
         let mut output = String::new();
-
         let t_gen_start = Instant::now();
         let mut first_token_time: Option<u128> = None;
 
-        // UTF-8 decoder
         let mut decoder = encoding_rs::UTF_8.new_decoder();
 
-        // Sampler chain with all parameters
         let mut sampler = LlamaSampler::chain_simple([
             LlamaSampler::penalties(
                 self.opts.repeat_last_n,
@@ -284,38 +248,35 @@ impl LlamaCppEngine {
             LlamaSampler::dist(self.opts.seed),
         ]);
 
+        // Generation loop
         while n_cur < n_len {
-            let token = sampler.sample(&ctx, batch.n_tokens() - 1);
+            // Sample the next token from the most recent logits (idx -1 = last output)
+            let token = sampler.sample(&ctx, -1);
             sampler.accept(token);
 
-            // Record first token time
             if first_token_time.is_none() {
                 first_token_time = Some(t_start.elapsed().as_millis());
             }
 
-            // Check end of generation
             if model.is_eog_token(token) {
                 break;
             }
 
-            // Decode token to string
             let output_bytes = model.token_to_bytes(token, Special::Tokenize)?;
             let mut token_str = String::with_capacity(32);
             let _ = decoder.decode_to_string(&output_bytes, &mut token_str, false);
-
             output.push_str(&token_str);
 
-            // Invoke callback
             let continue_gen = callback(token_str);
             if !continue_gen {
                 break;
             }
 
-            // Prepare next batch
-            batch.clear();
+            // Add next token to context
+            let mut batch = LlamaBatch::new(1, 1);
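+            // Single-token batch: decode just the newly sampled token at position n_cur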
             batch.add(token, n_cur, &[0], true)?;
-
             n_cur += 1;
+
             ctx.decode(&mut batch).with_context(|| "decode failed")?;
             n_decode += 1;
         }
@@ -338,14 +299,13 @@ impl LlamaCppEngine {
         })
     }
 
-    /// Generate text with default stdout printing (CLI compatibility)
+    /// Generate text with default stdout printing
     pub fn generate(&self, prompt: &str, max_tokens: i32) -> Result<GenerationResult> {
         self.generate_with_callback(prompt, max_tokens, |token| {
             print!("{}", token);
             let _ = std::io::Write::flush(&mut std::io::stdout());
-            true // continue
+            true
         })
-        // Note: println!() is done by caller in main.rs or separate
     }
 }
 
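A minimal usage sketch of the API above. The `LlamaCppEngine::new(opts)` constructor and the "model.gguf" path are assumptions for illustration; only `load_gguf`, `apply_chat_template`, `generate_with_callback`, and the option fields appear in this diff.

    let opts = LlamaCppOptions {
        temperature: 0.7,
        context_length: 4096,
        ..Default::default()
    };
    let mut engine = LlamaCppEngine::new(opts); // hypothetical constructor
    engine.load_gguf("model.gguf")?;            // hypothetical model path

    // Format a conversation with the model's built-in chat template
    let prompt = engine.apply_chat_template(&[ChatMessage {
        role: "user".to_string(),
        content: "Hello!".to_string(),
    }])?;

    // Stream up to 128 tokens; return false from the callback to abort early
    let result = engine.generate_with_callback(&prompt, 128, |token| {
        print!("{}", token);
        true
    })?;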