93 changes: 52 additions & 41 deletions xllm/core/framework/batch/batch_input_builder.cpp
@@ -189,15 +189,16 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
static_cast<int32_t>(state_.flatten_tokens_vec.size()) -
static_cast<int32_t>(state.flatten_tokens_vec.size());
for (const auto& idx : state.selected_token_idxes) {
state_.selected_token_idxes.push_back(idx + selected_token_idxes_offset);
state_.selected_token_idxes.emplace_back(idx +
selected_token_idxes_offset);
}
state_.sampling_params.insert(state_.sampling_params.end(),
state.sampling_params.begin(),
state.sampling_params.end());
int32_t sample_idxes_offset =
static_cast<int32_t>(state_.sample_idxes.size());
for (const auto& idx : state.sample_idxes) {
state_.sample_idxes.push_back(idx + sample_idxes_offset);
state_.sample_idxes.emplace_back(idx + sample_idxes_offset);
}
state_.unique_token_ids_vec.insert(state_.unique_token_ids_vec.end(),
state.unique_token_ids_vec.begin(),
@@ -217,15 +218,18 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
state_.q_seq_lens.insert(state_.q_seq_lens.end(),
state.q_seq_lens.begin(),
state.q_seq_lens.end());
state_.kv_cache_tokens_nums.insert(state_.kv_cache_tokens_nums.end(),
state.kv_cache_tokens_nums.begin(),
state.kv_cache_tokens_nums.end());
#elif defined(USE_MLU)
int32_t seq_len_offset = state_.seq_lens.back();
// skip the first element which is 0
for (size_t i = 1; i < state.seq_lens.size(); ++i) {
state_.seq_lens.push_back(state.seq_lens[i] + seq_len_offset);
state_.seq_lens.emplace_back(state.seq_lens[i] + seq_len_offset);
}
int32_t q_seq_len_offset = state_.q_seq_lens.back();
for (size_t i = 1; i < state.q_seq_lens.size(); ++i) {
state_.q_seq_lens.push_back(state.q_seq_lens[i] + q_seq_len_offset);
state_.q_seq_lens.emplace_back(state.q_seq_lens[i] + q_seq_len_offset);
}
#endif
state_.new_token_slot_ids.insert(state_.new_token_slot_ids.end(),
@@ -286,12 +290,13 @@ void BatchInputBuilder::process_single_sequence(
state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
state.max_seq_len = std::max(state.max_seq_len, seq_len);
state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
state.kv_cache_tokens_nums.emplace_back(n_kv_cache_tokens);
#if defined(USE_NPU)
state.seq_lens.push_back(seq_len);
state.q_seq_lens.push_back(q_seq_len);
state.seq_lens.emplace_back(seq_len);
state.q_seq_lens.emplace_back(q_seq_len);
#elif defined(USE_MLU)
state.seq_lens.push_back(state.seq_lens.back() + seq_len);
state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
state.seq_lens.emplace_back(state.seq_lens.back() + seq_len);
state.q_seq_lens.emplace_back(state.q_seq_lens.back() + q_seq_len);
#endif
// Process tokens and positions
extract_tokens_and_positions(sequence, n_kv_cache_tokens, seq_len, state_ptr);
@@ -317,8 +322,8 @@ void BatchInputBuilder::process_single_sequence(
// Input for beam search kernel
if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search() &&
sequence->num_generated_tokens() > 0) {
state.acc_logprob_vec.push_back(sequence->get_average_logprob() *
sequence->num_generated_tokens());
state.acc_logprob_vec.emplace_back(sequence->get_average_logprob() *
sequence->num_generated_tokens());
}
}

@@ -343,15 +348,15 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
if (use_mrope_) {
const auto& args = *args_;
MPositionHelper helper(*sequence, args);
state.mrope_positions_vec.push_back(helper.get_positions());
state.mrope_positions_vec.emplace_back(helper.get_positions());
}

// Process each token
for (uint32_t j = n_kv_cache_tokens; j < seq_len; ++j) {
state.flatten_tokens_vec.push_back(token_ids[j]);
state.flatten_tokens_vec.emplace_back(token_ids[j]);

if (!use_mrope_) {
state.flatten_positions_vec.push_back(static_cast<int32_t>(j));
state.flatten_positions_vec.emplace_back(static_cast<int32_t>(j));
}

// Handle sampling for last tokens
@@ -365,10 +370,10 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
if (n_tokens == seq_len) {
// last chunk of prefill and decode
// add -1 as extra token id
state.extra_token_ids.push_back(-1);
state.embedding_ids.push_back(sequence->get_embedding_id());
state.extra_token_ids.emplace_back(-1);
state.embedding_ids.emplace_back(sequence->get_embedding_id());
} else {
state.extra_token_ids.push_back(token_ids[seq_len]);
state.extra_token_ids.emplace_back(token_ids[seq_len]);
}
}

@@ -387,8 +392,8 @@ void BatchInputBuilder::handle_sampling_parameters(
--adjusted_token_to_count_map[token_id];

// Select token for sampling
state.selected_token_idxes.push_back(state.flatten_tokens_vec.size() - 1);
state.sampling_params.push_back(sequence->sampling_param());
state.selected_token_idxes.emplace_back(state.flatten_tokens_vec.size() - 1);
state.sampling_params.emplace_back(sequence->sampling_param());

// Process unique tokens
const auto& seq_token_counts = sequence->token_to_count_map();
@@ -404,19 +409,19 @@
(it != adjusted_token_to_count_map.end()) ? it->second : 0;

if (count > adjust_count) {
ids.push_back(token_id);
counts.push_back(count - adjust_count);
ids.emplace_back(token_id);
counts.emplace_back(count - adjust_count);
}
}

state.unique_token_lens_vec.push_back(static_cast<int32_t>(ids.size()));
state.unique_token_lens_vec.emplace_back(static_cast<int32_t>(ids.size()));

// Mark sample token if it's the last token
// TODO add test
// in chunked prefill condition, if allowed_max_token = 128, n_tokens=1000,
// n_kv_cache_tokens=256, q_seq_len = 128, seq_len=384
if (token_position == seq_len - 1) {
state.sample_idxes.push_back(
state.sample_idxes.emplace_back(
static_cast<int32_t>(state.selected_token_idxes.size() - 1));
}
}
@@ -447,7 +452,7 @@ void BatchInputBuilder::setup_kv_cache_info(
int32_t block_size = 0;
for (const auto& block : blocks) {
block_size = block.size();
block_ids.push_back(block.id());
block_ids.emplace_back(block.id());
u_block_ids.emplace_back(block.id());
}

@@ -483,13 +488,13 @@ void BatchInputBuilder::setup_continuous_kv_cache_info(
std::vector<int64_t> cache_slot_offsets;
cache_slot_offsets.reserve(seq_len - n_kv_cache_tokens);
for (int32_t i = n_kv_cache_tokens; i < seq_len; ++i) {
cache_slot_offsets.push_back(kv_cache_start_offset +
i * FLAGS_cache_size_per_token);
cache_slot_offsets.emplace_back(kv_cache_start_offset +
i * FLAGS_cache_size_per_token);
}
state.new_cache_slot_offsets.insert(state.new_cache_slot_offsets.end(),
cache_slot_offsets.begin(),
cache_slot_offsets.end());
state.kv_cache_start_offsets.push_back(kv_cache_start_offset);
state.kv_cache_start_offsets.emplace_back(kv_cache_start_offset);
}

void BatchInputBuilder::padding_decode_batch_size(
@@ -506,22 +511,23 @@
// add padding tokens to the batch
for (int32_t i = num_sequences_; i < min_decoding_batch_size; ++i) {
for (int32_t k = 0; k < num_decoding_tokens; ++k) {
state_.flatten_tokens_vec.push_back(0);
state_.flatten_tokens_vec.emplace_back(0);
if (!use_mrope_) {
state_.flatten_positions_vec.push_back(0);
state_.flatten_positions_vec.emplace_back(0);
} else {
state_.mrope_positions_vec.push_back(
state_.mrope_positions_vec.emplace_back(
torch::zeros({3, 1}, torch::kInt));
}
state_.new_token_slot_ids.push_back(0);
state_.new_token_slot_ids.emplace_back(0);
}
#if defined(USE_NPU)
state_.seq_lens.push_back(num_decoding_tokens);
state_.q_seq_lens.push_back(num_decoding_tokens);
state_.seq_lens.emplace_back(num_decoding_tokens);
state_.q_seq_lens.emplace_back(num_decoding_tokens);
#elif defined(USE_MLU)
state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
num_decoding_tokens);
state_.seq_lens.emplace_back(state_.seq_lens.back() +
num_decoding_tokens);
state_.q_seq_lens.emplace_back(state_.q_seq_lens.back() +
num_decoding_tokens);
#endif
state_.block_tables_vec.emplace_back();
}
@@ -554,6 +560,8 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
input_params.kv_max_seq_len = state_.max_seq_len;
input_params.q_max_seq_len = state_.q_max_seq_len;
input_params.kv_seq_lens = torch::tensor(state_.seq_lens, torch::kInt);
input_params.kv_cache_tokens_nums =
torch::tensor(state_.kv_cache_tokens_nums, torch::kInt);
input_params.q_seq_lens = torch::tensor(state_.q_seq_lens, torch::kInt);
input_params.kv_seq_lens_vec = std::move(state_.seq_lens);
input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
@@ -640,6 +648,9 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
raw_forward_input.seq_lens = std::move(state_.seq_lens);
raw_forward_input.q_seq_lens = std::move(state_.q_seq_lens);
raw_forward_input.kv_cache_tokens_nums =
std::move(state_.kv_cache_tokens_nums);

raw_forward_input.new_token_slot_ids = std::move(state_.new_token_slot_ids);
raw_forward_input.block_tables_vec = std::move(state_.block_tables_vec);
raw_forward_input.num_sequences = num_sequences_;
@@ -702,17 +713,17 @@ void BatchInputBuilder::process_swap_block_infos(
src_indices.reserve(swap_blocks.size());
dst_indices.reserve(swap_blocks.size());

src_indices.push_back(swap_blocks[0].device_block_id);
dst_indices.push_back(swap_blocks[0].host_block_id);
src_indices.emplace_back(swap_blocks[0].device_block_id);
dst_indices.emplace_back(swap_blocks[0].host_block_id);
for (size_t i = 1; i < swap_blocks.size(); i++) {
dst_indices.push_back(swap_blocks[i].host_block_id);
dst_indices.emplace_back(swap_blocks[i].host_block_id);
if (swap_blocks[i].device_block_id != current_src) {
src_indices.push_back(swap_blocks[i].device_block_id);
cum_sum.push_back(i);
src_indices.emplace_back(swap_blocks[i].device_block_id);
cum_sum.emplace_back(i);
current_src = swap_blocks[i].device_block_id;
}
}
cum_sum.push_back(swap_blocks.size());
cum_sum.emplace_back(swap_blocks.size());

raw_forward_input.swap_blocks.clear();
raw_forward_input.src_block_indices = std::move(src_indices);
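Note (not part of the diff): the process_sequences_multithreaded hunks above fold each worker's BuilderState into the shared state_ by shifting every locally recorded index by the size the global vector had before the merge. A minimal standalone sketch of that offset-merge pattern, with hypothetical names and simplified fields rather than the real xLLM types:

#include <cstdint>
#include <vector>

struct PartialState {
  std::vector<int32_t> flatten_tokens;        // tokens gathered by one worker
  std::vector<int32_t> selected_token_idxes;  // indices into flatten_tokens
};

// Indices recorded by a worker point into its own flatten_tokens, so they are
// rebased by the number of tokens already present in the global state.
void merge_into(PartialState& global, const PartialState& part) {
  const auto offset = static_cast<int32_t>(global.flatten_tokens.size());
  global.flatten_tokens.insert(global.flatten_tokens.end(),
                               part.flatten_tokens.begin(),
                               part.flatten_tokens.end());
  for (int32_t idx : part.selected_token_idxes) {
    global.selected_token_idxes.emplace_back(idx + offset);
  }
}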
1 change: 1 addition & 0 deletions xllm/core/framework/batch/batch_input_builder.h
@@ -86,6 +86,7 @@ class BatchInputBuilder {
BatchForwardType batch_forward_type;
uint32_t max_seq_len = 0;
uint32_t q_max_seq_len = 0;
std::vector<int32_t> kv_cache_tokens_nums;
#if defined(USE_NPU)
std::vector<int32_t> seq_lens;
std::vector<int32_t> q_seq_lens;
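Note (not part of the diff): the USE_NPU and USE_MLU branches in the .cpp hunks above keep seq_lens in different conventions: per-sequence lengths on NPU, and a running prefix sum with a leading 0 on MLU, which is why the merge loop skips the first element. A hypothetical helper showing the relationship between the two forms (illustration only, not xLLM code):

#include <cstdint>
#include <vector>

// NPU-style per-sequence lengths -> MLU-style cumulative lengths with a leading 0.
std::vector<int32_t> to_cumulative(const std::vector<int32_t>& lens) {
  std::vector<int32_t> cum;
  cum.reserve(lens.size() + 1);
  cum.emplace_back(0);
  for (int32_t len : lens) {
    cum.emplace_back(cum.back() + len);
  }
  return cum;  // e.g. {384, 1, 1} -> {0, 384, 385, 386}
}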
20 changes: 20 additions & 0 deletions xllm/core/framework/model/model_input_params.h
@@ -78,6 +78,17 @@ struct ModelInputParams {
params.embedding_ids = std::move(embedding_ids);
params.extra_token_ids = std::move(extra_token_ids);
params.dp_ep_padding_data = dp_ep_padding_data;
params.kv_cache_tokens_nums_host = std::vector<int>(kv_cache_tokens_nums.data_ptr<int>(),
kv_cache_tokens_nums.data_ptr<int>() +
kv_cache_tokens_nums.numel());

params.kv_cache_tokens_nums = safe_to(kv_cache_tokens_nums,device);
params.history_compressed_kv = safe_to(history_compressed_kv, device);
params.history_k_rope = safe_to(history_k_rope, device);
params.ring_cur_seqlen = safe_to(ring_cur_seqlen, device);
params.ring_cur_seqlen_host = ring_cur_seqlen_host;
params.ring_cache_seqlen = safe_to(ring_cache_seqlen, device);
params.ring_cache_seqlen_host = ring_cache_seqlen_host;
#if defined(USE_NPU)
params.layer_synchronizer = layer_synchronizer;
#endif
@@ -198,8 +209,17 @@ struct ModelInputParams {
#endif

DpEpPaddingData dp_ep_padding_data;

torch::Tensor expert_load_data;

torch::Tensor kv_cache_tokens_nums;
std::vector<int> kv_cache_tokens_nums_host;
torch::Tensor history_compressed_kv;
torch::Tensor history_k_rope;
torch::Tensor ring_cur_seqlen;
std::vector<int> ring_cur_seqlen_host;
torch::Tensor ring_cache_seqlen;
std::vector<int> ring_cache_seqlen_host;
// new slot offsets for continuous kvcache
// used to store kv-cache to right position
// IntTensor: [n_tokens]
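Note (not part of the diff): the new kv_cache_tokens_nums_host assignment above copies an int32 tensor into a std::vector on the host before the tensor itself is sent to the device via safe_to. A minimal sketch of that copy, assuming a contiguous CPU tensor of dtype kInt (the helper name is hypothetical and not part of the PR):

#include <torch/torch.h>

#include <cstdint>
#include <vector>

// data_ptr<int32_t>() is only valid for a contiguous CPU tensor holding int32,
// so normalize the tensor before taking raw pointers into its storage.
std::vector<int32_t> to_host_vector(const torch::Tensor& t) {
  auto cpu = t.to(torch::kCPU).to(torch::kInt).contiguous();
  return std::vector<int32_t>(cpu.data_ptr<int32_t>(),
                              cpu.data_ptr<int32_t>() + cpu.numel());
}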