@@ -529,22 +529,22 @@ static std::vector<std::string> run_phase2_batch(Qwen3LM *
529529 }
530530
531531 // Batched decode loop, partial LM head: only project [TOKEN_IM_END..V)
532- Timer t_decode;
533- int V_eff = V - TOKEN_IM_END; // 65559 vs 217204
534- std::vector<float > logits_cond ((size_t ) V_eff * N);
535- std::vector<float > logits_uncond ((size_t ) V_eff * N);
536- std::vector<int > tokens (N);
532+ Timer t_decode;
533+ int V_eff = V - TOKEN_IM_END;
537534
538- // CFG: single forward with 2*N (cond + uncond)
539- int N2 = use_cfg ? 2 * N : N;
540- std::vector<int > tokens_2n (N2), sets_2n (N2);
541- std::vector<float > logits_2n ((size_t ) V_eff * N2);
542- if (use_cfg) {
543- for (int i = 0 ; i < N; i++) {
544- sets_2n[i] = cond_sets[i];
545- sets_2n[N + i] = uncond_sets[i];
546- }
547- }
535+ // Pre-allocate batched arrays for the maximum possible size (N or 2*N for CFG)
536+ int max_N2 = use_cfg ? 2 * N : N;
537+ std::vector<int > batch_tokens (max_N2);
538+ std::vector<int > batch_sets (max_N2);
539+ std::vector<float > batch_logits ((size_t ) V_eff * max_N2);
540+
541+ // This array maps the compact "active" index back to the original sequence index (0 to N-1)
542+ std::vector<int > active_to_orig (N);
543+
544+ // Compact array for CPU sampling: holds only EOS + audio codes (compact_V entries), so the sampler never sees the other entries of the V_eff-wide [TOKEN_IM_END..V) logits slice
545+ int audio_code_offset = AUDIO_CODE_BASE - TOKEN_IM_END;
546+ int compact_V = AUDIO_CODE_COUNT + 1 ;
547+ std::vector<float > compact_logits (compact_V);
548548
549549 int n_active = N;
550550 for (int i = 0 ; i < N; i++) {
@@ -554,58 +554,85 @@ static std::vector<std::string> run_phase2_batch(Qwen3LM *
554554 }
555555
556556 for (int step = 0 ; step < max_tokens && n_active > 0 ; step++) {
557- // Collect tokens (done sequences feed their last token, result ignored)
558- for (int i = 0 ; i < N; i++) {
559- tokens[i] = seqs[i].last_token ;
560- }
557+ int current_active = 0 ;
561558
562- if (use_cfg) {
563- // Single batched forward: cond[0..N-1] + uncond[N..2N-1]
564- for (int i = 0 ; i < N; i++) {
565- tokens_2n[i] = tokens[i];
566- tokens_2n[N + i] = tokens[i];
559+ // 1. DYNAMIC COMPACTION: Loop through all N sequences, but only gather the active ones!
560+ for (int i = 0 ; i < N; i++) {
561+ if (!seqs[i].done ) {
562+ active_to_orig[current_active] = i; // Remember that this slot belongs to sequence 'i'
563+
564+ if (use_cfg) {
565+ // Place the Cond token/set in the first half
566+ batch_tokens[current_active] = seqs[i].last_token ;
567+ batch_sets[current_active] = cond_sets[i];
568+
569+ // Place the Uncond token/set exactly n_active elements later
570+ batch_tokens[n_active + current_active] = seqs[i].last_token ;
571+ batch_sets[n_active + current_active] = uncond_sets[i];
572+ } else {
573+ batch_tokens[current_active] = seqs[i].last_token ;
574+ batch_sets[current_active] = cond_sets[i];
575+ }
576+ current_active++;
567577 }
568- qw3lm_forward_batch (m, tokens_2n.data (), sets_2n.data (), N2, logits_2n.data (), TOKEN_IM_END, V_eff);
569- memcpy (logits_cond.data (), logits_2n.data (), (size_t ) V_eff * N * sizeof (float ));
570- memcpy (logits_uncond.data (), logits_2n.data () + (size_t ) V_eff * N, (size_t ) V_eff * N * sizeof (float ));
571- } else {
572- qw3lm_forward_batch (m, tokens.data (), cond_sets.data (), N, logits_cond.data (), TOKEN_IM_END, V_eff);
573578 }
574579
575- // Per-sequence: CFG combine + sample (logits are [V_eff] starting at TOKEN_IM_END)
576- for (int i = 0 ; i < N; i++) {
577- if (seqs[i].done ) {
578- continue ;
579- }
580+ // 2. FORWARD PASS: GPU only computes attention for n_active sequences
581+ int actual_batch_size = use_cfg ? (2 * n_active) : n_active;
582+ qw3lm_forward_batch (m, batch_tokens.data (), batch_sets.data (), actual_batch_size, batch_logits.data (),
583+ TOKEN_IM_END, V_eff);
584+
585+ // 3. TARGETED CFG & LOGIT EXTRACTION
586+ for (int a = 0 ; a < n_active; a++) {
587+ int orig_i = active_to_orig[a]; // Map back to original sequence object
588+
589+ // Pointer to the conditional logits for THIS active sequence
590+ float * lc = batch_logits.data () + (size_t ) a * V_eff;
580591
581- float * lc = logits_cond.data () + (size_t ) i * V_eff;
582592 if (use_cfg) {
583- float * lu = logits_uncond.data () + (size_t ) i * V_eff;
584- for (int v = 0 ; v < V_eff; v++) {
585- lc[v] = lu[v] + cfg_scale * (lc[v] - lu[v]);
593+ // Pointer to the unconditional logits (offset by n_active)
594+ float * lu = batch_logits.data () + (size_t ) (n_active + a) * V_eff;
595+
596+ // Targeted CFG math: only combine EOS + audio codes, the entries that are extracted for sampling below; every other entry of the [TOKEN_IM_END..V) slice is left untouched.
597+ lc[0 ] = lu[0 ] + cfg_scale * (lc[0 ] - lu[0 ]); // EOS token
598+ for (int c = 0 ; c < AUDIO_CODE_COUNT; c++) {
599+ int idx = audio_code_offset + c;
600+ lc[idx] = lu[idx] + cfg_scale * (lc[idx] - lu[idx]);
586601 }
587602 }
588603
589- // Mask the 24-token gap: indices 1..AUDIO_CODE_BASE-TOKEN_IM_END-1
590- // (index 0 = TOKEN_IM_END = EOS, index 24+ = audio codes)
591- for (int v = 1 ; v < AUDIO_CODE_BASE - TOKEN_IM_END; v ++) {
592- lc[v ] = -1e9f ;
604+ // Extract ONLY the valid target tokens into the tiny compact array
605+ compact_logits[ 0 ] = lc[ 0 ];
606+ for (int c = 0 ; c < AUDIO_CODE_COUNT; c ++) {
607+ compact_logits[c + 1 ] = lc[audio_code_offset + c] ;
593608 }
594- int tok = sample_top_k_p (lc, V_eff, temperature, top_p, top_k, seqs[i].rng ) + TOKEN_IM_END;
595- seqs[i].last_token = tok;
609+
610+ // CPU sampling runs over compact_V (= AUDIO_CODE_COUNT + 1) entries instead of the full V_eff-wide slice
611+ int compact_tok =
612+ sample_top_k_p (compact_logits.data (), compact_V, temperature, top_p, top_k, seqs[orig_i].rng );
613+
614+ // Map the sampled index back to global vocabulary ID
615+ int tok = (compact_tok == 0 ) ? TOKEN_IM_END : (AUDIO_CODE_BASE + compact_tok - 1 );
616+
617+ seqs[orig_i].last_token = tok;
596618
597619 if (tok == TOKEN_IM_END) {
598- seqs[i].done = true ;
599- n_active--;
600- } else if (tok >= AUDIO_CODE_BASE && tok < AUDIO_CODE_BASE + AUDIO_CODE_COUNT) {
601- seqs[i].audio_codes .push_back (tok - AUDIO_CODE_BASE);
620+ seqs[orig_i].done = true ;
621+ } else {
622+ seqs[orig_i].audio_codes .push_back (tok - AUDIO_CODE_BASE);
602623 }
603624 }
604625
605- int total_codes = 0 ;
626+ // 4. UPDATE ACTIVE COUNT for the next loop iteration
627+ int next_active_count = 0 ;
628+ int total_codes = 0 ;
606629 for (int i = 0 ; i < N; i++) {
630+ if (!seqs[i].done ) {
631+ next_active_count++;
632+ }
607633 total_codes += (int ) seqs[i].audio_codes .size ();
608634 }
635+ n_active = next_active_count;
609636
610637 if ((step + 1 ) % 50 == 0 ) {
611638 double elapsed = t_decode.ms () / 1000.0 ;
0 commit comments