Skip to content

Commit 3d646c9

Browse files
jdluzen and Copilot
authored and committed
perf: optimize Phase 2 batch generation with dynamic compaction by 3-12% (#20)
* perf: improve batch generation in step 1 by 3-12% * remove comments * remove comments
1 parent 89747a2 commit 3d646c9

1 file changed

Lines changed: 76 additions & 49 deletions

File tree

tools/ace-qwen3.cpp

Lines changed: 76 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -529,22 +529,22 @@ static std::vector<std::string> run_phase2_batch(Qwen3LM *
529529
}
530530

531531
// Batched decode loop, partial LM head: only project [TOKEN_IM_END..V)
532-
Timer t_decode;
533-
int V_eff = V - TOKEN_IM_END; // 65559 vs 217204
534-
std::vector<float> logits_cond((size_t) V_eff * N);
535-
std::vector<float> logits_uncond((size_t) V_eff * N);
536-
std::vector<int> tokens(N);
532+
Timer t_decode;
533+
int V_eff = V - TOKEN_IM_END;
537534

538-
// CFG: single forward with 2*N (cond + uncond)
539-
int N2 = use_cfg ? 2 * N : N;
540-
std::vector<int> tokens_2n(N2), sets_2n(N2);
541-
std::vector<float> logits_2n((size_t) V_eff * N2);
542-
if (use_cfg) {
543-
for (int i = 0; i < N; i++) {
544-
sets_2n[i] = cond_sets[i];
545-
sets_2n[N + i] = uncond_sets[i];
546-
}
547-
}
535+
// Pre-allocate batched arrays for the maximum possible size (N or 2*N for CFG)
536+
int max_N2 = use_cfg ? 2 * N : N;
537+
std::vector<int> batch_tokens(max_N2);
538+
std::vector<int> batch_sets(max_N2);
539+
std::vector<float> batch_logits((size_t) V_eff * max_N2);
540+
541+
// This array maps the compact "active" index back to the original sequence index (0 to N-1)
542+
std::vector<int> active_to_orig(N);
543+
544+
// Tiny array for CPU sampling (EOS token + Audio Codes) to prevent sorting 150,000 text logits
545+
int audio_code_offset = AUDIO_CODE_BASE - TOKEN_IM_END;
546+
int compact_V = AUDIO_CODE_COUNT + 1;
547+
std::vector<float> compact_logits(compact_V);
548548

549549
int n_active = N;
550550
for (int i = 0; i < N; i++) {
@@ -554,58 +554,85 @@ static std::vector<std::string> run_phase2_batch(Qwen3LM *
554554
}
555555

556556
for (int step = 0; step < max_tokens && n_active > 0; step++) {
557-
// Collect tokens (done sequences feed their last token, result ignored)
558-
for (int i = 0; i < N; i++) {
559-
tokens[i] = seqs[i].last_token;
560-
}
557+
int current_active = 0;
561558

562-
if (use_cfg) {
563-
// Single batched forward: cond[0..N-1] + uncond[N..2N-1]
564-
for (int i = 0; i < N; i++) {
565-
tokens_2n[i] = tokens[i];
566-
tokens_2n[N + i] = tokens[i];
559+
// 1. DYNAMIC COMPACTION: Loop through all N sequences, but only gather the active ones!
560+
for (int i = 0; i < N; i++) {
561+
if (!seqs[i].done) {
562+
active_to_orig[current_active] = i; // Remember that this slot belongs to sequence 'i'
563+
564+
if (use_cfg) {
565+
// Place the Cond token/set in the first half
566+
batch_tokens[current_active] = seqs[i].last_token;
567+
batch_sets[current_active] = cond_sets[i];
568+
569+
// Place the Uncond token/set exactly n_active elements later
570+
batch_tokens[n_active + current_active] = seqs[i].last_token;
571+
batch_sets[n_active + current_active] = uncond_sets[i];
572+
} else {
573+
batch_tokens[current_active] = seqs[i].last_token;
574+
batch_sets[current_active] = cond_sets[i];
575+
}
576+
current_active++;
567577
}
568-
qw3lm_forward_batch(m, tokens_2n.data(), sets_2n.data(), N2, logits_2n.data(), TOKEN_IM_END, V_eff);
569-
memcpy(logits_cond.data(), logits_2n.data(), (size_t) V_eff * N * sizeof(float));
570-
memcpy(logits_uncond.data(), logits_2n.data() + (size_t) V_eff * N, (size_t) V_eff * N * sizeof(float));
571-
} else {
572-
qw3lm_forward_batch(m, tokens.data(), cond_sets.data(), N, logits_cond.data(), TOKEN_IM_END, V_eff);
573578
}
574579

575-
// Per-sequence: CFG combine + sample (logits are [V_eff] starting at TOKEN_IM_END)
576-
for (int i = 0; i < N; i++) {
577-
if (seqs[i].done) {
578-
continue;
579-
}
580+
// 2. FORWARD PASS: GPU only computes attention for n_active sequences
581+
int actual_batch_size = use_cfg ? (2 * n_active) : n_active;
582+
qw3lm_forward_batch(m, batch_tokens.data(), batch_sets.data(), actual_batch_size, batch_logits.data(),
583+
TOKEN_IM_END, V_eff);
584+
585+
// 3. TARGETED CFG & LOGIT EXTRACTION
586+
for (int a = 0; a < n_active; a++) {
587+
int orig_i = active_to_orig[a]; // Map back to original sequence object
588+
589+
// Pointer to the conditional logits for THIS active sequence
590+
float * lc = batch_logits.data() + (size_t) a * V_eff;
580591

581-
float * lc = logits_cond.data() + (size_t) i * V_eff;
582592
if (use_cfg) {
583-
float * lu = logits_uncond.data() + (size_t) i * V_eff;
584-
for (int v = 0; v < V_eff; v++) {
585-
lc[v] = lu[v] + cfg_scale * (lc[v] - lu[v]);
593+
// Pointer to the unconditional logits (offset by n_active)
594+
float * lu = batch_logits.data() + (size_t) (n_active + a) * V_eff;
595+
596+
// Targeted CFG Math: Only apply it to EOS + Audio Codes. Skip the 150,000 text tokens!
597+
lc[0] = lu[0] + cfg_scale * (lc[0] - lu[0]); // EOS token
598+
for (int c = 0; c < AUDIO_CODE_COUNT; c++) {
599+
int idx = audio_code_offset + c;
600+
lc[idx] = lu[idx] + cfg_scale * (lc[idx] - lu[idx]);
586601
}
587602
}
588603

589-
// Mask the 24-token gap: indices 1..AUDIO_CODE_BASE-TOKEN_IM_END-1
590-
// (index 0 = TOKEN_IM_END = EOS, index 24+ = audio codes)
591-
for (int v = 1; v < AUDIO_CODE_BASE - TOKEN_IM_END; v++) {
592-
lc[v] = -1e9f;
604+
// Extract ONLY the valid target tokens into the tiny compact array
605+
compact_logits[0] = lc[0];
606+
for (int c = 0; c < AUDIO_CODE_COUNT; c++) {
607+
compact_logits[c + 1] = lc[audio_code_offset + c];
593608
}
594-
int tok = sample_top_k_p(lc, V_eff, temperature, top_p, top_k, seqs[i].rng) + TOKEN_IM_END;
595-
seqs[i].last_token = tok;
609+
610+
// CPU samples instantly because it only has to sort ~2049 items instead of 150,000+
611+
int compact_tok =
612+
sample_top_k_p(compact_logits.data(), compact_V, temperature, top_p, top_k, seqs[orig_i].rng);
613+
614+
// Map the sampled index back to global vocabulary ID
615+
int tok = (compact_tok == 0) ? TOKEN_IM_END : (AUDIO_CODE_BASE + compact_tok - 1);
616+
617+
seqs[orig_i].last_token = tok;
596618

597619
if (tok == TOKEN_IM_END) {
598-
seqs[i].done = true;
599-
n_active--;
600-
} else if (tok >= AUDIO_CODE_BASE && tok < AUDIO_CODE_BASE + AUDIO_CODE_COUNT) {
601-
seqs[i].audio_codes.push_back(tok - AUDIO_CODE_BASE);
620+
seqs[orig_i].done = true;
621+
} else {
622+
seqs[orig_i].audio_codes.push_back(tok - AUDIO_CODE_BASE);
602623
}
603624
}
604625

605-
int total_codes = 0;
626+
// 4. UPDATE ACTIVE COUNT for the next loop iteration
627+
int next_active_count = 0;
628+
int total_codes = 0;
606629
for (int i = 0; i < N; i++) {
630+
if (!seqs[i].done) {
631+
next_active_count++;
632+
}
607633
total_codes += (int) seqs[i].audio_codes.size();
608634
}
635+
n_active = next_active_count;
609636

610637
if ((step + 1) % 50 == 0) {
611638
double elapsed = t_decode.ms() / 1000.0;

0 commit comments

Comments (0)