@@ -189,15 +189,16 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
189189 static_cast <int32_t >(state_.flatten_tokens_vec .size ()) -
190190 static_cast <int32_t >(state.flatten_tokens_vec .size ());
191191 for (const auto & idx : state.selected_token_idxes ) {
192- state_.selected_token_idxes .push_back (idx + selected_token_idxes_offset);
192+ state_.selected_token_idxes .emplace_back (idx +
193+ selected_token_idxes_offset);
193194 }
194195 state_.sampling_params .insert (state_.sampling_params .end (),
195196 state.sampling_params .begin (),
196197 state.sampling_params .end ());
197198 int32_t sample_idxes_offset =
198199 static_cast <int32_t >(state_.sample_idxes .size ());
199200 for (const auto & idx : state.sample_idxes ) {
200- state_.sample_idxes .push_back (idx + sample_idxes_offset);
201+ state_.sample_idxes .emplace_back (idx + sample_idxes_offset);
201202 }
202203 state_.unique_token_ids_vec .insert (state_.unique_token_ids_vec .end (),
203204 state.unique_token_ids_vec .begin (),
@@ -217,15 +218,18 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
217218 state_.q_seq_lens .insert (state_.q_seq_lens .end (),
218219 state.q_seq_lens .begin (),
219220 state.q_seq_lens .end ());
221+ state_.kv_cache_tokens_nums .insert (state_.kv_cache_tokens_nums .end (),
222+ state.kv_cache_tokens_nums .begin (),
223+ state.kv_cache_tokens_nums .end ());
220224#elif defined(USE_MLU)
221225 int32_t seq_len_offset = state_.seq_lens .back ();
222226 // skip the first element which is 0
223227 for (size_t i = 1 ; i < state.seq_lens .size (); ++i) {
224- state_.seq_lens .push_back (state.seq_lens [i] + seq_len_offset);
228+ state_.seq_lens .emplace_back (state.seq_lens [i] + seq_len_offset);
225229 }
226230 int32_t q_seq_len_offset = state_.q_seq_lens .back ();
227231 for (size_t i = 1 ; i < state.q_seq_lens .size (); ++i) {
228- state_.q_seq_lens .push_back (state.q_seq_lens [i] + q_seq_len_offset);
232+ state_.q_seq_lens .emplace_back (state.q_seq_lens [i] + q_seq_len_offset);
229233 }
230234#endif
231235 state_.new_token_slot_ids .insert (state_.new_token_slot_ids .end (),
@@ -286,12 +290,13 @@ void BatchInputBuilder::process_single_sequence(
286290 state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0 );
287291 state.max_seq_len = std::max (state.max_seq_len , seq_len);
288292 state.q_max_seq_len = std::max (state.q_max_seq_len , q_seq_len);
293+ state.kv_cache_tokens_nums .emplace_back (n_kv_cache_tokens);
289294#if defined(USE_NPU)
290- state.seq_lens .push_back (seq_len);
291- state.q_seq_lens .push_back (q_seq_len);
295+ state.seq_lens .emplace_back (seq_len);
296+ state.q_seq_lens .emplace_back (q_seq_len);
292297#elif defined(USE_MLU)
293- state.seq_lens .push_back (state.seq_lens .back () + seq_len);
294- state.q_seq_lens .push_back (state.q_seq_lens .back () + q_seq_len);
298+ state.seq_lens .emplace_back (state.seq_lens .back () + seq_len);
299+ state.q_seq_lens .emplace_back (state.q_seq_lens .back () + q_seq_len);
295300#endif
296301 // Process tokens and positions
297302 extract_tokens_and_positions (sequence, n_kv_cache_tokens, seq_len, state_ptr);
@@ -317,8 +322,8 @@ void BatchInputBuilder::process_single_sequence(
317322 // Input for beam search kernel
318323 if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search () &&
319324 sequence->num_generated_tokens () > 0 ) {
320- state.acc_logprob_vec .push_back (sequence->get_average_logprob () *
321- sequence->num_generated_tokens ());
325+ state.acc_logprob_vec .emplace_back (sequence->get_average_logprob () *
326+ sequence->num_generated_tokens ());
322327 }
323328}
324329
@@ -343,15 +348,15 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
343348 if (use_mrope_) {
344349 const auto & args = *args_;
345350 MPositionHelper helper (*sequence, args);
346- state.mrope_positions_vec .push_back (helper.get_positions ());
351+ state.mrope_positions_vec .emplace_back (helper.get_positions ());
347352 }
348353
349354 // Process each token
350355 for (uint32_t j = n_kv_cache_tokens; j < seq_len; ++j) {
351- state.flatten_tokens_vec .push_back (token_ids[j]);
356+ state.flatten_tokens_vec .emplace_back (token_ids[j]);
352357
353358 if (!use_mrope_) {
354- state.flatten_positions_vec .push_back (static_cast <int32_t >(j));
359+ state.flatten_positions_vec .emplace_back (static_cast <int32_t >(j));
355360 }
356361
357362 // Handle sampling for last tokens
@@ -365,10 +370,10 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
365370 if (n_tokens == seq_len) {
366371 // last chunk of prefill and decode
367372 // add -1 as extra token id
368- state.extra_token_ids .push_back (-1 );
369- state.embedding_ids .push_back (sequence->get_embedding_id ());
373+ state.extra_token_ids .emplace_back (-1 );
374+ state.embedding_ids .emplace_back (sequence->get_embedding_id ());
370375 } else {
371- state.extra_token_ids .push_back (token_ids[seq_len]);
376+ state.extra_token_ids .emplace_back (token_ids[seq_len]);
372377 }
373378}
374379
@@ -387,8 +392,8 @@ void BatchInputBuilder::handle_sampling_parameters(
387392 --adjusted_token_to_count_map[token_id];
388393
389394 // Select token for sampling
390- state.selected_token_idxes .push_back (state.flatten_tokens_vec .size () - 1 );
391- state.sampling_params .push_back (sequence->sampling_param ());
395+ state.selected_token_idxes .emplace_back (state.flatten_tokens_vec .size () - 1 );
396+ state.sampling_params .emplace_back (sequence->sampling_param ());
392397
393398 // Process unique tokens
394399 const auto & seq_token_counts = sequence->token_to_count_map ();
@@ -404,19 +409,19 @@ void BatchInputBuilder::handle_sampling_parameters(
404409 (it != adjusted_token_to_count_map.end ()) ? it->second : 0 ;
405410
406411 if (count > adjust_count) {
407- ids.push_back (token_id);
408- counts.push_back (count - adjust_count);
412+ ids.emplace_back (token_id);
413+ counts.emplace_back (count - adjust_count);
409414 }
410415 }
411416
412- state.unique_token_lens_vec .push_back (static_cast <int32_t >(ids.size ()));
417+ state.unique_token_lens_vec .emplace_back (static_cast <int32_t >(ids.size ()));
413418
414419 // Mark sample token if it's the last token
415420 // TODO add test
416421 // in chunked prefill condition, if allowed_max_token = 128, n_tokens=1000,
417422 // n_kv_cache_tokens=256, q_seq_len = 128, seq_len=384
418423 if (token_position == seq_len - 1 ) {
419- state.sample_idxes .push_back (
424+ state.sample_idxes .emplace_back (
420425 static_cast <int32_t >(state.selected_token_idxes .size () - 1 ));
421426 }
422427}
@@ -447,7 +452,7 @@ void BatchInputBuilder::setup_kv_cache_info(
447452 int32_t block_size = 0 ;
448453 for (const auto & block : blocks) {
449454 block_size = block.size ();
450- block_ids.push_back (block.id ());
455+ block_ids.emplace_back (block.id ());
451456 u_block_ids.emplace_back (block.id ());
452457 }
453458
@@ -483,13 +488,13 @@ void BatchInputBuilder::setup_continuous_kv_cache_info(
483488 std::vector<int64_t > cache_slot_offsets;
484489 cache_slot_offsets.reserve (seq_len - n_kv_cache_tokens);
485490 for (int32_t i = n_kv_cache_tokens; i < seq_len; ++i) {
486- cache_slot_offsets.push_back (kv_cache_start_offset +
487- i * FLAGS_cache_size_per_token);
491+ cache_slot_offsets.emplace_back (kv_cache_start_offset +
492+ i * FLAGS_cache_size_per_token);
488493 }
489494 state.new_cache_slot_offsets .insert (state.new_cache_slot_offsets .end (),
490495 cache_slot_offsets.begin (),
491496 cache_slot_offsets.end ());
492- state.kv_cache_start_offsets .push_back (kv_cache_start_offset);
497+ state.kv_cache_start_offsets .emplace_back (kv_cache_start_offset);
493498}
494499
495500void BatchInputBuilder::padding_decode_batch_size (
@@ -506,22 +511,23 @@ void BatchInputBuilder::padding_decode_batch_size(
506511 // add padding tokens to the batch
507512 for (int32_t i = num_sequences_; i < min_decoding_batch_size; ++i) {
508513 for (int32_t k = 0 ; k < num_decoding_tokens; ++k) {
509- state_.flatten_tokens_vec .push_back (0 );
514+ state_.flatten_tokens_vec .emplace_back (0 );
510515 if (!use_mrope_) {
511- state_.flatten_positions_vec .push_back (0 );
516+ state_.flatten_positions_vec .emplace_back (0 );
512517 } else {
513- state_.mrope_positions_vec .push_back (
518+ state_.mrope_positions_vec .emplace_back (
514519 torch::zeros ({3 , 1 }, torch::kInt ));
515520 }
516- state_.new_token_slot_ids .push_back (0 );
521+ state_.new_token_slot_ids .emplace_back (0 );
517522 }
518523#if defined(USE_NPU)
519- state_.seq_lens .push_back (num_decoding_tokens);
520- state_.q_seq_lens .push_back (num_decoding_tokens);
524+ state_.seq_lens .emplace_back (num_decoding_tokens);
525+ state_.q_seq_lens .emplace_back (num_decoding_tokens);
521526#elif defined(USE_MLU)
522- state_.seq_lens .push_back (state_.seq_lens .back () + num_decoding_tokens);
523- state_.q_seq_lens .push_back (state_.q_seq_lens .back () +
524- num_decoding_tokens);
527+ state_.seq_lens .emplace_back (state_.seq_lens .back () +
528+ num_decoding_tokens);
529+ state_.q_seq_lens .emplace_back (state_.q_seq_lens .back () +
530+ num_decoding_tokens);
525531#endif
526532 state_.block_tables_vec .emplace_back ();
527533 }
@@ -554,6 +560,8 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
554560 input_params.kv_max_seq_len = state_.max_seq_len ;
555561 input_params.q_max_seq_len = state_.q_max_seq_len ;
556562 input_params.kv_seq_lens = torch::tensor (state_.seq_lens , torch::kInt );
563+ input_params.kv_cache_tokens_nums =
564+ torch::tensor (state_.kv_cache_tokens_nums , torch::kInt );
557565 input_params.q_seq_lens = torch::tensor (state_.q_seq_lens , torch::kInt );
558566 input_params.kv_seq_lens_vec = std::move (state_.seq_lens );
559567 input_params.q_seq_lens_vec = std::move (state_.q_seq_lens );
@@ -640,6 +648,9 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
640648 raw_forward_input.q_max_seq_len = state_.q_max_seq_len ;
641649 raw_forward_input.seq_lens = std::move (state_.seq_lens );
642650 raw_forward_input.q_seq_lens = std::move (state_.q_seq_lens );
651+ raw_forward_input.kv_cache_tokens_nums =
652+ std::move (state_.kv_cache_tokens_nums );
653+
643654 raw_forward_input.new_token_slot_ids = std::move (state_.new_token_slot_ids );
644655 raw_forward_input.block_tables_vec = std::move (state_.block_tables_vec );
645656 raw_forward_input.num_sequences = num_sequences_;
@@ -702,17 +713,17 @@ void BatchInputBuilder::process_swap_block_infos(
702713 src_indices.reserve (swap_blocks.size ());
703714 dst_indices.reserve (swap_blocks.size ());
704715
705- src_indices.push_back (swap_blocks[0 ].device_block_id );
706- dst_indices.push_back (swap_blocks[0 ].host_block_id );
716+ src_indices.emplace_back (swap_blocks[0 ].device_block_id );
717+ dst_indices.emplace_back (swap_blocks[0 ].host_block_id );
707718 for (size_t i = 1 ; i < swap_blocks.size (); i++) {
708- dst_indices.push_back (swap_blocks[i].host_block_id );
719+ dst_indices.emplace_back (swap_blocks[i].host_block_id );
709720 if (swap_blocks[i].device_block_id != current_src) {
710- src_indices.push_back (swap_blocks[i].device_block_id );
711- cum_sum.push_back (i);
721+ src_indices.emplace_back (swap_blocks[i].device_block_id );
722+ cum_sum.emplace_back (i);
712723 current_src = swap_blocks[i].device_block_id ;
713724 }
714725 }
715- cum_sum.push_back (swap_blocks.size ());
726+ cum_sum.emplace_back (swap_blocks.size ());
716727
717728 raw_forward_input.swap_blocks .clear ();
718729 raw_forward_input.src_block_indices = std::move (src_indices);
0 commit comments