Skip to content

Commit 9204bbd

Browse files
committed
feat: support DeepSeek prefix cache.
1 parent 086d258 commit 9204bbd

15 files changed

+264
-75
lines changed

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,16 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
189189
static_cast<int32_t>(state_.flatten_tokens_vec.size()) -
190190
static_cast<int32_t>(state.flatten_tokens_vec.size());
191191
for (const auto& idx : state.selected_token_idxes) {
192-
state_.selected_token_idxes.push_back(idx + selected_token_idxes_offset);
192+
state_.selected_token_idxes.emplace_back(idx +
193+
selected_token_idxes_offset);
193194
}
194195
state_.sampling_params.insert(state_.sampling_params.end(),
195196
state.sampling_params.begin(),
196197
state.sampling_params.end());
197198
int32_t sample_idxes_offset =
198199
static_cast<int32_t>(state_.sample_idxes.size());
199200
for (const auto& idx : state.sample_idxes) {
200-
state_.sample_idxes.push_back(idx + sample_idxes_offset);
201+
state_.sample_idxes.emplace_back(idx + sample_idxes_offset);
201202
}
202203
state_.unique_token_ids_vec.insert(state_.unique_token_ids_vec.end(),
203204
state.unique_token_ids_vec.begin(),
@@ -217,15 +218,18 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
217218
state_.q_seq_lens.insert(state_.q_seq_lens.end(),
218219
state.q_seq_lens.begin(),
219220
state.q_seq_lens.end());
221+
state_.kv_cache_tokens_nums.insert(state_.kv_cache_tokens_nums.end(),
222+
state.kv_cache_tokens_nums.begin(),
223+
state.kv_cache_tokens_nums.end());
220224
#elif defined(USE_MLU)
221225
int32_t seq_len_offset = state_.seq_lens.back();
222226
// skip the first element which is 0
223227
for (size_t i = 1; i < state.seq_lens.size(); ++i) {
224-
state_.seq_lens.push_back(state.seq_lens[i] + seq_len_offset);
228+
state_.seq_lens.emplace_back(state.seq_lens[i] + seq_len_offset);
225229
}
226230
int32_t q_seq_len_offset = state_.q_seq_lens.back();
227231
for (size_t i = 1; i < state.q_seq_lens.size(); ++i) {
228-
state_.q_seq_lens.push_back(state.q_seq_lens[i] + q_seq_len_offset);
232+
state_.q_seq_lens.emplace_back(state.q_seq_lens[i] + q_seq_len_offset);
229233
}
230234
#endif
231235
state_.new_token_slot_ids.insert(state_.new_token_slot_ids.end(),
@@ -286,12 +290,13 @@ void BatchInputBuilder::process_single_sequence(
286290
state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
287291
state.max_seq_len = std::max(state.max_seq_len, seq_len);
288292
state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
293+
state.kv_cache_tokens_nums.emplace_back(n_kv_cache_tokens);
289294
#if defined(USE_NPU)
290-
state.seq_lens.push_back(seq_len);
291-
state.q_seq_lens.push_back(q_seq_len);
295+
state.seq_lens.emplace_back(seq_len);
296+
state.q_seq_lens.emplace_back(q_seq_len);
292297
#elif defined(USE_MLU)
293-
state.seq_lens.push_back(state.seq_lens.back() + seq_len);
294-
state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
298+
state.seq_lens.emplace_back(state.seq_lens.back() + seq_len);
299+
state.q_seq_lens.emplace_back(state.q_seq_lens.back() + q_seq_len);
295300
#endif
296301
// Process tokens and positions
297302
extract_tokens_and_positions(sequence, n_kv_cache_tokens, seq_len, state_ptr);
@@ -317,8 +322,8 @@ void BatchInputBuilder::process_single_sequence(
317322
// Input for beam search kernel
318323
if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search() &&
319324
sequence->num_generated_tokens() > 0) {
320-
state.acc_logprob_vec.push_back(sequence->get_average_logprob() *
321-
sequence->num_generated_tokens());
325+
state.acc_logprob_vec.emplace_back(sequence->get_average_logprob() *
326+
sequence->num_generated_tokens());
322327
}
323328
}
324329

@@ -343,15 +348,15 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
343348
if (use_mrope_) {
344349
const auto& args = *args_;
345350
MPositionHelper helper(*sequence, args);
346-
state.mrope_positions_vec.push_back(helper.get_positions());
351+
state.mrope_positions_vec.emplace_back(helper.get_positions());
347352
}
348353

349354
// Process each token
350355
for (uint32_t j = n_kv_cache_tokens; j < seq_len; ++j) {
351-
state.flatten_tokens_vec.push_back(token_ids[j]);
356+
state.flatten_tokens_vec.emplace_back(token_ids[j]);
352357

353358
if (!use_mrope_) {
354-
state.flatten_positions_vec.push_back(static_cast<int32_t>(j));
359+
state.flatten_positions_vec.emplace_back(static_cast<int32_t>(j));
355360
}
356361

357362
// Handle sampling for last tokens
@@ -365,10 +370,10 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
365370
if (n_tokens == seq_len) {
366371
// last chunk of prefill and decode
367372
// add -1 as extra token id
368-
state.extra_token_ids.push_back(-1);
369-
state.embedding_ids.push_back(sequence->get_embedding_id());
373+
state.extra_token_ids.emplace_back(-1);
374+
state.embedding_ids.emplace_back(sequence->get_embedding_id());
370375
} else {
371-
state.extra_token_ids.push_back(token_ids[seq_len]);
376+
state.extra_token_ids.emplace_back(token_ids[seq_len]);
372377
}
373378
}
374379

@@ -387,8 +392,8 @@ void BatchInputBuilder::handle_sampling_parameters(
387392
--adjusted_token_to_count_map[token_id];
388393

389394
// Select token for sampling
390-
state.selected_token_idxes.push_back(state.flatten_tokens_vec.size() - 1);
391-
state.sampling_params.push_back(sequence->sampling_param());
395+
state.selected_token_idxes.emplace_back(state.flatten_tokens_vec.size() - 1);
396+
state.sampling_params.emplace_back(sequence->sampling_param());
392397

393398
// Process unique tokens
394399
const auto& seq_token_counts = sequence->token_to_count_map();
@@ -404,19 +409,19 @@ void BatchInputBuilder::handle_sampling_parameters(
404409
(it != adjusted_token_to_count_map.end()) ? it->second : 0;
405410

406411
if (count > adjust_count) {
407-
ids.push_back(token_id);
408-
counts.push_back(count - adjust_count);
412+
ids.emplace_back(token_id);
413+
counts.emplace_back(count - adjust_count);
409414
}
410415
}
411416

412-
state.unique_token_lens_vec.push_back(static_cast<int32_t>(ids.size()));
417+
state.unique_token_lens_vec.emplace_back(static_cast<int32_t>(ids.size()));
413418

414419
// Mark sample token if it's the last token
415420
// TODO add test
416421
// in chunked prefill condition, if allowed_max_token = 128, n_tokens=1000,
417422
// n_kv_cache_tokens=256, q_seq_len = 128, seq_len=384
418423
if (token_position == seq_len - 1) {
419-
state.sample_idxes.push_back(
424+
state.sample_idxes.emplace_back(
420425
static_cast<int32_t>(state.selected_token_idxes.size() - 1));
421426
}
422427
}
@@ -447,7 +452,7 @@ void BatchInputBuilder::setup_kv_cache_info(
447452
int32_t block_size = 0;
448453
for (const auto& block : blocks) {
449454
block_size = block.size();
450-
block_ids.push_back(block.id());
455+
block_ids.emplace_back(block.id());
451456
u_block_ids.emplace_back(block.id());
452457
}
453458

@@ -483,13 +488,13 @@ void BatchInputBuilder::setup_continuous_kv_cache_info(
483488
std::vector<int64_t> cache_slot_offsets;
484489
cache_slot_offsets.reserve(seq_len - n_kv_cache_tokens);
485490
for (int32_t i = n_kv_cache_tokens; i < seq_len; ++i) {
486-
cache_slot_offsets.push_back(kv_cache_start_offset +
487-
i * FLAGS_cache_size_per_token);
491+
cache_slot_offsets.emplace_back(kv_cache_start_offset +
492+
i * FLAGS_cache_size_per_token);
488493
}
489494
state.new_cache_slot_offsets.insert(state.new_cache_slot_offsets.end(),
490495
cache_slot_offsets.begin(),
491496
cache_slot_offsets.end());
492-
state.kv_cache_start_offsets.push_back(kv_cache_start_offset);
497+
state.kv_cache_start_offsets.emplace_back(kv_cache_start_offset);
493498
}
494499

495500
void BatchInputBuilder::padding_decode_batch_size(
@@ -506,22 +511,23 @@ void BatchInputBuilder::padding_decode_batch_size(
506511
// add padding tokens to the batch
507512
for (int32_t i = num_sequences_; i < min_decoding_batch_size; ++i) {
508513
for (int32_t k = 0; k < num_decoding_tokens; ++k) {
509-
state_.flatten_tokens_vec.push_back(0);
514+
state_.flatten_tokens_vec.emplace_back(0);
510515
if (!use_mrope_) {
511-
state_.flatten_positions_vec.push_back(0);
516+
state_.flatten_positions_vec.emplace_back(0);
512517
} else {
513-
state_.mrope_positions_vec.push_back(
518+
state_.mrope_positions_vec.emplace_back(
514519
torch::zeros({3, 1}, torch::kInt));
515520
}
516-
state_.new_token_slot_ids.push_back(0);
521+
state_.new_token_slot_ids.emplace_back(0);
517522
}
518523
#if defined(USE_NPU)
519-
state_.seq_lens.push_back(num_decoding_tokens);
520-
state_.q_seq_lens.push_back(num_decoding_tokens);
524+
state_.seq_lens.emplace_back(num_decoding_tokens);
525+
state_.q_seq_lens.emplace_back(num_decoding_tokens);
521526
#elif defined(USE_MLU)
522-
state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
523-
state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
524-
num_decoding_tokens);
527+
state_.seq_lens.emplace_back(state_.seq_lens.back() +
528+
num_decoding_tokens);
529+
state_.q_seq_lens.emplace_back(state_.q_seq_lens.back() +
530+
num_decoding_tokens);
525531
#endif
526532
state_.block_tables_vec.emplace_back();
527533
}
@@ -554,6 +560,8 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
554560
input_params.kv_max_seq_len = state_.max_seq_len;
555561
input_params.q_max_seq_len = state_.q_max_seq_len;
556562
input_params.kv_seq_lens = torch::tensor(state_.seq_lens, torch::kInt);
563+
input_params.kv_cache_tokens_nums =
564+
torch::tensor(state_.kv_cache_tokens_nums, torch::kInt);
557565
input_params.q_seq_lens = torch::tensor(state_.q_seq_lens, torch::kInt);
558566
input_params.kv_seq_lens_vec = std::move(state_.seq_lens);
559567
input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
@@ -640,6 +648,9 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
640648
raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
641649
raw_forward_input.seq_lens = std::move(state_.seq_lens);
642650
raw_forward_input.q_seq_lens = std::move(state_.q_seq_lens);
651+
raw_forward_input.kv_cache_tokens_nums =
652+
std::move(state_.kv_cache_tokens_nums);
653+
643654
raw_forward_input.new_token_slot_ids = std::move(state_.new_token_slot_ids);
644655
raw_forward_input.block_tables_vec = std::move(state_.block_tables_vec);
645656
raw_forward_input.num_sequences = num_sequences_;
@@ -702,17 +713,17 @@ void BatchInputBuilder::process_swap_block_infos(
702713
src_indices.reserve(swap_blocks.size());
703714
dst_indices.reserve(swap_blocks.size());
704715

705-
src_indices.push_back(swap_blocks[0].device_block_id);
706-
dst_indices.push_back(swap_blocks[0].host_block_id);
716+
src_indices.emplace_back(swap_blocks[0].device_block_id);
717+
dst_indices.emplace_back(swap_blocks[0].host_block_id);
707718
for (size_t i = 1; i < swap_blocks.size(); i++) {
708-
dst_indices.push_back(swap_blocks[i].host_block_id);
719+
dst_indices.emplace_back(swap_blocks[i].host_block_id);
709720
if (swap_blocks[i].device_block_id != current_src) {
710-
src_indices.push_back(swap_blocks[i].device_block_id);
711-
cum_sum.push_back(i);
721+
src_indices.emplace_back(swap_blocks[i].device_block_id);
722+
cum_sum.emplace_back(i);
712723
current_src = swap_blocks[i].device_block_id;
713724
}
714725
}
715-
cum_sum.push_back(swap_blocks.size());
726+
cum_sum.emplace_back(swap_blocks.size());
716727

717728
raw_forward_input.swap_blocks.clear();
718729
raw_forward_input.src_block_indices = std::move(src_indices);

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class BatchInputBuilder {
8686
BatchForwardType batch_forward_type;
8787
uint32_t max_seq_len = 0;
8888
uint32_t q_max_seq_len = 0;
89+
std::vector<int32_t> kv_cache_tokens_nums;
8990
#if defined(USE_NPU)
9091
std::vector<int32_t> seq_lens;
9192
std::vector<int32_t> q_seq_lens;

xllm/core/framework/model/model_input_params.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,17 @@ struct ModelInputParams {
7878
params.embedding_ids = std::move(embedding_ids);
7979
params.extra_token_ids = std::move(extra_token_ids);
8080
params.dp_ep_padding_data = dp_ep_padding_data;
81+
params.kv_cache_tokens_nums_host = std::vector<int>(kv_cache_tokens_nums.data_ptr<int>(),
82+
kv_cache_tokens_nums.data_ptr<int>() +
83+
kv_cache_tokens_nums.numel());
84+
85+
params.kv_cache_tokens_nums = safe_to(kv_cache_tokens_nums,device);
86+
params.history_compressed_kv = safe_to(history_compressed_kv, device);
87+
params.history_k_rope = safe_to(history_k_rope, device);
88+
params.ring_cur_seqlen = safe_to(ring_cur_seqlen, device);
89+
params.ring_cur_seqlen_host = ring_cur_seqlen_host;
90+
params.ring_cache_seqlen = safe_to(ring_cache_seqlen, device);
91+
params.ring_cache_seqlen_host = ring_cache_seqlen_host;
8192
#if defined(USE_NPU)
8293
params.layer_synchronizer = layer_synchronizer;
8394
#endif
@@ -198,8 +209,17 @@ struct ModelInputParams {
198209
#endif
199210

200211
DpEpPaddingData dp_ep_padding_data;
212+
201213
torch::Tensor expert_load_data;
202214

215+
torch::Tensor kv_cache_tokens_nums;
216+
std::vector<int> kv_cache_tokens_nums_host;
217+
torch::Tensor history_compressed_kv;
218+
torch::Tensor history_k_rope;
219+
torch::Tensor ring_cur_seqlen;
220+
std::vector<int> ring_cur_seqlen_host;
221+
torch::Tensor ring_cache_seqlen;
222+
std::vector<int> ring_cache_seqlen_host;
203223
// new slot offsets for continuous kvcache
204224
// used to store kv-cache to right position
205225
// IntTensor: [n_tokens]

0 commit comments

Comments
 (0)