diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp
index 7989ba6dc..a69204237 100755
--- a/xllm/core/framework/batch/batch_input_builder.cpp
+++ b/xllm/core/framework/batch/batch_input_builder.cpp
@@ -189,7 +189,8 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
         static_cast<int32_t>(state_.flatten_tokens_vec.size()) -
         static_cast<int32_t>(state.flatten_tokens_vec.size());
     for (const auto& idx : state.selected_token_idxes) {
-      state_.selected_token_idxes.push_back(idx + selected_token_idxes_offset);
+      state_.selected_token_idxes.emplace_back(idx +
+                                               selected_token_idxes_offset);
     }
     state_.sampling_params.insert(state_.sampling_params.end(),
                                   state.sampling_params.begin(),
@@ -197,7 +198,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
     int32_t sample_idxes_offset =
         static_cast<int32_t>(state_.sample_idxes.size());
     for (const auto& idx : state.sample_idxes) {
-      state_.sample_idxes.push_back(idx + sample_idxes_offset);
+      state_.sample_idxes.emplace_back(idx + sample_idxes_offset);
     }
     state_.unique_token_ids_vec.insert(state_.unique_token_ids_vec.end(),
                                        state.unique_token_ids_vec.begin(),
@@ -217,15 +218,18 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
     state_.q_seq_lens.insert(state_.q_seq_lens.end(),
                              state.q_seq_lens.begin(),
                              state.q_seq_lens.end());
+    state_.kv_cache_tokens_nums.insert(state_.kv_cache_tokens_nums.end(),
+                                       state.kv_cache_tokens_nums.begin(),
+                                       state.kv_cache_tokens_nums.end());
 #elif defined(USE_MLU)
     int32_t seq_len_offset = state_.seq_lens.back();
     // skip the first element which is 0
     for (size_t i = 1; i < state.seq_lens.size(); ++i) {
-      state_.seq_lens.push_back(state.seq_lens[i] + seq_len_offset);
+      state_.seq_lens.emplace_back(state.seq_lens[i] + seq_len_offset);
     }
     int32_t q_seq_len_offset = state_.q_seq_lens.back();
     for (size_t i = 1; i < state.q_seq_lens.size(); ++i) {
-      state_.q_seq_lens.push_back(state.q_seq_lens[i] + q_seq_len_offset);
+      state_.q_seq_lens.emplace_back(state.q_seq_lens[i] + q_seq_len_offset);
     }
 #endif
     state_.new_token_slot_ids.insert(state_.new_token_slot_ids.end(),
@@ -286,12 +290,13 @@ void BatchInputBuilder::process_single_sequence(
   state.empty_kv_cache = state.empty_kv_cache && (n_kv_cache_tokens == 0);
   state.max_seq_len = std::max(state.max_seq_len, seq_len);
   state.q_max_seq_len = std::max(state.q_max_seq_len, q_seq_len);
+  state.kv_cache_tokens_nums.emplace_back(n_kv_cache_tokens);
 #if defined(USE_NPU)
-  state.seq_lens.push_back(seq_len);
-  state.q_seq_lens.push_back(q_seq_len);
+  state.seq_lens.emplace_back(seq_len);
+  state.q_seq_lens.emplace_back(q_seq_len);
 #elif defined(USE_MLU)
-  state.seq_lens.push_back(state.seq_lens.back() + seq_len);
-  state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
+  state.seq_lens.emplace_back(state.seq_lens.back() + seq_len);
+  state.q_seq_lens.emplace_back(state.q_seq_lens.back() + q_seq_len);
 #endif
 
   // Process tokens and positions
   extract_tokens_and_positions(sequence, n_kv_cache_tokens, seq_len, state_ptr);
@@ -317,8 +322,8 @@ void BatchInputBuilder::process_single_sequence(
   // Input for beam search kernel
   if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search() &&
       sequence->num_generated_tokens() > 0) {
-    state.acc_logprob_vec.push_back(sequence->get_average_logprob() *
-                                    sequence->num_generated_tokens());
+    state.acc_logprob_vec.emplace_back(sequence->get_average_logprob() *
+                                       sequence->num_generated_tokens());
   }
 }
 
@@ -343,15 +348,15 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
   if (use_mrope_) {
     const auto& args = *args_;
     MPositionHelper helper(*sequence, args);
-    state.mrope_positions_vec.push_back(helper.get_positions());
+    state.mrope_positions_vec.emplace_back(helper.get_positions());
   }
 
   // Process each token
   for (uint32_t j = n_kv_cache_tokens; j < seq_len; ++j) {
-    state.flatten_tokens_vec.push_back(token_ids[j]);
+    state.flatten_tokens_vec.emplace_back(token_ids[j]);
 
     if (!use_mrope_) {
-      state.flatten_positions_vec.push_back(static_cast<int32_t>(j));
+      state.flatten_positions_vec.emplace_back(static_cast<int32_t>(j));
     }
 
     // Handle sampling for last tokens
@@ -365,10 +370,10 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
   if (n_tokens == seq_len) {
     // last chunk of prefill and decode
     // add -1 as extra token id
-    state.extra_token_ids.push_back(-1);
-    state.embedding_ids.push_back(sequence->get_embedding_id());
+    state.extra_token_ids.emplace_back(-1);
+    state.embedding_ids.emplace_back(sequence->get_embedding_id());
   } else {
-    state.extra_token_ids.push_back(token_ids[seq_len]);
+    state.extra_token_ids.emplace_back(token_ids[seq_len]);
   }
 }
 
@@ -387,8 +392,8 @@ void BatchInputBuilder::handle_sampling_parameters(
     --adjusted_token_to_count_map[token_id];
 
   // Select token for sampling
-  state.selected_token_idxes.push_back(state.flatten_tokens_vec.size() - 1);
-  state.sampling_params.push_back(sequence->sampling_param());
+  state.selected_token_idxes.emplace_back(state.flatten_tokens_vec.size() - 1);
+  state.sampling_params.emplace_back(sequence->sampling_param());
 
   // Process unique tokens
   const auto& seq_token_counts = sequence->token_to_count_map();
@@ -404,19 +409,19 @@ void BatchInputBuilder::handle_sampling_parameters(
         (it != adjusted_token_to_count_map.end()) ? it->second : 0;
     if (count > adjust_count) {
-      ids.push_back(token_id);
-      counts.push_back(count - adjust_count);
+      ids.emplace_back(token_id);
+      counts.emplace_back(count - adjust_count);
     }
   }
 
-  state.unique_token_lens_vec.push_back(static_cast<int32_t>(ids.size()));
+  state.unique_token_lens_vec.emplace_back(static_cast<int32_t>(ids.size()));
 
   // Mark sample token if it's the last token
   // TODO add test
   // in chunked prefill condition, if allowed_max_token = 128, n_tokens=1000,
   // n_kv_cache_tokens=256, q_seq_len = 128, seq_len=384
   if (token_position == seq_len - 1) {
-    state.sample_idxes.push_back(
+    state.sample_idxes.emplace_back(
         static_cast<int32_t>(state.selected_token_idxes.size() - 1));
   }
 }
@@ -447,7 +452,7 @@ void BatchInputBuilder::setup_kv_cache_info(
   int32_t block_size = 0;
   for (const auto& block : blocks) {
     block_size = block.size();
-    block_ids.push_back(block.id());
+    block_ids.emplace_back(block.id());
     u_block_ids.emplace_back(block.id());
   }
 
@@ -483,13 +488,13 @@ void BatchInputBuilder::setup_continuous_kv_cache_info(
   std::vector<int64_t> cache_slot_offsets;
   cache_slot_offsets.reserve(seq_len - n_kv_cache_tokens);
   for (int32_t i = n_kv_cache_tokens; i < seq_len; ++i) {
-    cache_slot_offsets.push_back(kv_cache_start_offset +
-                                 i * FLAGS_cache_size_per_token);
+    cache_slot_offsets.emplace_back(kv_cache_start_offset +
+                                    i * FLAGS_cache_size_per_token);
   }
   state.new_cache_slot_offsets.insert(state.new_cache_slot_offsets.end(),
                                       cache_slot_offsets.begin(),
                                       cache_slot_offsets.end());
-  state.kv_cache_start_offsets.push_back(kv_cache_start_offset);
+  state.kv_cache_start_offsets.emplace_back(kv_cache_start_offset);
 }
 
 void BatchInputBuilder::padding_decode_batch_size(
@@ -506,22 +511,23 @@ void BatchInputBuilder::padding_decode_batch_size(
     // add padding tokens to the batch
     for (int32_t i = num_sequences_; i < min_decoding_batch_size; ++i) {
       for (int32_t k = 0; k < num_decoding_tokens; ++k) {
-        state_.flatten_tokens_vec.push_back(0);
+        state_.flatten_tokens_vec.emplace_back(0);
         if (!use_mrope_) {
-          state_.flatten_positions_vec.push_back(0);
+          state_.flatten_positions_vec.emplace_back(0);
         } else {
-          state_.mrope_positions_vec.push_back(
+          state_.mrope_positions_vec.emplace_back(
               torch::zeros({3, 1}, torch::kInt));
         }
-        state_.new_token_slot_ids.push_back(0);
+        state_.new_token_slot_ids.emplace_back(0);
       }
 #if defined(USE_NPU)
-      state_.seq_lens.push_back(num_decoding_tokens);
-      state_.q_seq_lens.push_back(num_decoding_tokens);
+      state_.seq_lens.emplace_back(num_decoding_tokens);
+      state_.q_seq_lens.emplace_back(num_decoding_tokens);
 #elif defined(USE_MLU)
-      state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
-      state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
-                                  num_decoding_tokens);
+      state_.seq_lens.emplace_back(state_.seq_lens.back() +
+                                   num_decoding_tokens);
+      state_.q_seq_lens.emplace_back(state_.q_seq_lens.back() +
+                                     num_decoding_tokens);
 #endif
       state_.block_tables_vec.emplace_back();
     }
@@ -554,6 +560,8 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   input_params.kv_max_seq_len = state_.max_seq_len;
   input_params.q_max_seq_len = state_.q_max_seq_len;
   input_params.kv_seq_lens = torch::tensor(state_.seq_lens, torch::kInt);
+  input_params.kv_cache_tokens_nums =
+      torch::tensor(state_.kv_cache_tokens_nums, torch::kInt);
   input_params.q_seq_lens = torch::tensor(state_.q_seq_lens, torch::kInt);
   input_params.kv_seq_lens_vec = std::move(state_.seq_lens);
   input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
@@ -640,6 +648,9 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
   raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
   raw_forward_input.seq_lens = std::move(state_.seq_lens);
   raw_forward_input.q_seq_lens = std::move(state_.q_seq_lens);
+  raw_forward_input.kv_cache_tokens_nums =
+      std::move(state_.kv_cache_tokens_nums);
+
   raw_forward_input.new_token_slot_ids = std::move(state_.new_token_slot_ids);
   raw_forward_input.block_tables_vec = std::move(state_.block_tables_vec);
   raw_forward_input.num_sequences = num_sequences_;
@@ -702,17 +713,17 @@ void BatchInputBuilder::process_swap_block_infos(
   src_indices.reserve(swap_blocks.size());
   dst_indices.reserve(swap_blocks.size());
 
-  src_indices.push_back(swap_blocks[0].device_block_id);
-  dst_indices.push_back(swap_blocks[0].host_block_id);
+  src_indices.emplace_back(swap_blocks[0].device_block_id);
+  dst_indices.emplace_back(swap_blocks[0].host_block_id);
   for (size_t i = 1; i < swap_blocks.size(); i++) {
-    dst_indices.push_back(swap_blocks[i].host_block_id);
+    dst_indices.emplace_back(swap_blocks[i].host_block_id);
     if (swap_blocks[i].device_block_id != current_src) {
-      src_indices.push_back(swap_blocks[i].device_block_id);
-      cum_sum.push_back(i);
+      src_indices.emplace_back(swap_blocks[i].device_block_id);
+      cum_sum.emplace_back(i);
       current_src = swap_blocks[i].device_block_id;
     }
   }
-  cum_sum.push_back(swap_blocks.size());
+  cum_sum.emplace_back(swap_blocks.size());
 
   raw_forward_input.swap_blocks.clear();
   raw_forward_input.src_block_indices = std::move(src_indices);
diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h
index 212447249..1ef4b5174 100644
--- a/xllm/core/framework/batch/batch_input_builder.h
+++ b/xllm/core/framework/batch/batch_input_builder.h
@@ -86,6 +86,7 @@ class BatchInputBuilder {
     BatchForwardType batch_forward_type;
     uint32_t max_seq_len = 0;
     uint32_t q_max_seq_len = 0;
+    std::vector<int32_t> kv_cache_tokens_nums;
 #if defined(USE_NPU)
     std::vector<int32_t> seq_lens;
     std::vector<int32_t> q_seq_lens;
diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h
index a6227664b..dcf71d8d0 100644
--- a/xllm/core/framework/model/model_input_params.h
+++ b/xllm/core/framework/model/model_input_params.h
@@ -78,6 +78,17 @@ struct ModelInputParams {
     params.embedding_ids = std::move(embedding_ids);
     params.extra_token_ids = std::move(extra_token_ids);
     params.dp_ep_padding_data = dp_ep_padding_data;
+    params.kv_cache_tokens_nums_host =
+        std::vector<int32_t>(kv_cache_tokens_nums.data_ptr<int32_t>(),
+                             kv_cache_tokens_nums.data_ptr<int32_t>() +
+                                 kv_cache_tokens_nums.numel());
+
+    params.kv_cache_tokens_nums = safe_to(kv_cache_tokens_nums, device);
+    params.history_compressed_kv = safe_to(history_compressed_kv, device);
+    params.history_k_rope = safe_to(history_k_rope, device);
+    params.ring_cur_seqlen = safe_to(ring_cur_seqlen, device);
+    params.ring_cur_seqlen_host = ring_cur_seqlen_host;
+    params.ring_cache_seqlen = safe_to(ring_cache_seqlen, device);
+    params.ring_cache_seqlen_host = ring_cache_seqlen_host;
 #if defined(USE_NPU)
     params.layer_synchronizer = layer_synchronizer;
 #endif
@@ -198,8 +209,17 @@ struct ModelInputParams {
 #endif
 
   DpEpPaddingData dp_ep_padding_data;
+  torch::Tensor expert_load_data;
+  torch::Tensor kv_cache_tokens_nums;
+  std::vector<int32_t> kv_cache_tokens_nums_host;
+  torch::Tensor history_compressed_kv;
+  torch::Tensor history_k_rope;
+  torch::Tensor ring_cur_seqlen;
+  std::vector<int32_t> ring_cur_seqlen_host;
+  torch::Tensor ring_cache_seqlen;
+  std::vector<int32_t> ring_cache_seqlen_host;
 
   // new slot offsets for continuous kvcache
   // used to store kv-cache to right position
   // IntTensor: [n_tokens]
diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
index 608ff7e70..eca030306 100644
--- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
@@ -284,8 +284,10 @@ NpuDeepseekV2DecoderLayerImpl::NpuDeepseekV2DecoderLayerImpl(
   CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_);
   dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_;
 
-  param_from_args(prefill_param_, model_args, parallel_args, true);
-  param_from_args(decode_param_, model_args, parallel_args, false);
+  param_from_args(prefill_param_, model_args, parallel_args, true, false);
+  param_from_args(
+      prefill_param_prefixcache_, model_args, parallel_args, true, true);
+  param_from_args(decode_param_, model_args, parallel_args, false, false);
 
   initialize_tensors(options);
 }
@@ -346,8 +348,10 @@ void NpuDeepseekV2DecoderLayerImpl::param_from_args(
     atb_speed::deepseekV2::DecoderLayerParam& param,
     const ModelArgs& args,
     const ParallelArgs& parallel_args,
-    bool is_prefill) {
-  initialize_basic_parameters(param, args, parallel_args, is_prefill);
+    bool is_prefill,
+    bool is_prefixcache) {
+  initialize_basic_parameters(
+      param, args, parallel_args, is_prefill, is_prefixcache);
   initialize_attention_parameters(param, args, parallel_args);
   initialize_mlp_parameters(param, args, parallel_args);
   initialize_parallel_parameters(param, parallel_args);
@@ -393,10 +397,14 @@ void NpuDeepseekV2DecoderLayerImpl::initialize_basic_parameters(
     atb_speed::deepseekV2::DecoderLayerParam& param,
     const ModelArgs& args,
     const ParallelArgs& parallel_args,
-    bool is_prefill) {
+    bool is_prefill,
+    bool is_prefixcache) {
   param.isFA = false;
   param.isPrefill = is_prefill;
   param.isBF16 = args.dtype() == "bfloat16";
+  param.enablePrefixCache =
+      is_prefill && FLAGS_enable_prefix_cache && is_prefixcache;
+  param.isNzCache = FLAGS_enable_prefix_cache;
   param.enableSwiGLU = true;
   param.enableLcoc = true;
   // TODO: modify xllm_atb_layers
@@ -467,7 +475,6 @@ void NpuDeepseekV2DecoderLayerImpl::initialize_attention_parameters(
   }
 
   param.enableFA3 = false;           // TODO
-  param.isNzCache = false;           // TODO
   param.enableKvQuantLayer = false;  // TODO
 }
 
@@ -1459,6 +1466,8 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_layer() {
   name_ = "deepseek_v2_decoder_layer " + std::to_string(layer_id_);
   model_name_ = "DeepSeek_V2";
   CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_));
+  CHECK_OPERATION_STATUS_RETURN(
+      init_node(prefill_node_prefixcache_, prefill_param_prefixcache_));
   CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_));
   return atb::NO_ERROR;
 }
@@ -1525,18 +1534,30 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward(
   atb::Status st;
   // all micro batches are in same prefill/decode stage,
   // so, to judge empty_kv_cache, use input_params[0] here
-  if (input_params[0].global_empty_kv_cache) {
-    build_node_variant_pack(prefill_node_,
-                            x,
-                            cos_pos,
-                            sin_pos,
-                            attn_mask,
-                            kv_cache,
-                            input_params,
-                            true);
-    st = execute_node(prefill_node_, node_id, event, event_flag);
-    LOG_IF(FATAL, st != 0) << model_name_
-                           << "excute prefill layer fail, error code: " << st;
+  if (input_params[0].batch_forward_type.is_chunked_prefill()) {
+    build_node_variant_pack(prefill_node_prefixcache_,
+                            x,
+                            cos_pos,
+                            sin_pos,
+                            attn_mask,
+                            kv_cache,
+                            input_params,
+                            true);
+    st = execute_node(prefill_node_prefixcache_, node_id, event, event_flag);
+    LOG_IF(FATAL, st != 0)
+        << model_name_ << "execute prefill layer fail, error code: " << st;
+  } else if (input_params[0].batch_forward_type.is_prefill()) {
+    build_node_variant_pack(prefill_node_,
+                            x,
+                            cos_pos,
+                            sin_pos,
+                            attn_mask,
+                            kv_cache,
+                            input_params,
+                            true);
+    st = execute_node(prefill_node_, node_id, event, event_flag);
+    LOG_IF(FATAL, st != 0)
+        << model_name_ << "execute prefill layer fail, error code: " << st;
   } else {
     std::vector<torch::Tensor> attn_mask{tensor_placeholder_,
                                          tensor_placeholder_};
@@ -1617,9 +1638,9 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack(
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 12) =
       atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_);
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 13) =
-      atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_);
+      atb_speed::Utils::AtTensor2Tensor(input_params[0].kv_cache_tokens_nums);
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 13).hostData =
-      const_cast<int32_t*>(placeholder_vec_.data());
+      const_cast<int32_t*>(input_params[0].kv_cache_tokens_nums_host.data());
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 14) =
       atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_);
 
@@ -1865,8 +1886,9 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack(
         atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.dynamic_ep_idx());
     node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 29) =
         atb_speed::Utils::AtTensor2Tensor(dp_ep_padding.moe_idx());
+    int offset = 30;
     if (FLAGS_enable_eplb && layer_id_ >= decode_param_.firstKDenseReplace) {
-      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 30) =
+      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset++) =
           atb_speed::Utils::AtTensor2Tensor(expert_routing_map_);
       if (!is_prefill) {
         node.variantPack.outTensors.at(1) = atb_speed::Utils::AtTensor2Tensor(
@@ -1874,6 +1896,23 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack(
                 decode_param_.firstKDenseReplace]);
       }
     }
+    if (input_params[0].batch_forward_type.is_chunked_prefill()) {
+      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset) =
+          atb_speed::Utils::AtTensor2Tensor(
+              input_params[0].history_compressed_kv);
+      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset + 1) =
+          atb_speed::Utils::AtTensor2Tensor(input_params[0].history_k_rope);
+      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset + 2) =
+          atb_speed::Utils::AtTensor2Tensor(input_params[0].ring_cur_seqlen);
+      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset + 2)
+          .hostData =
+          const_cast<int32_t*>(input_params[0].ring_cur_seqlen_host.data());
+      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset + 3) =
+          atb_speed::Utils::AtTensor2Tensor(input_params[0].ring_cache_seqlen);
+      node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + offset + 3)
+          .hostData =
+          const_cast<int32_t*>(input_params[0].ring_cache_seqlen_host.data());
+    }
   }
 
   for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h
index 00c830c4f..73a84baa7 100644
--- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h
+++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h
@@ -151,7 +151,8 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer {
   void param_from_args(atb_speed::deepseekV2::DecoderLayerParam& param,
                        const ModelArgs& args,
                        const ParallelArgs& parallel_args,
-                       bool is_prefill);
+                       bool is_prefill,
+                       bool is_prefixcache);
   void reserve_experts_weights(int num_of_device_experts);
   void initialize_device_expert_list(int numdevice, int num_layers);
 
@@ -159,7 +160,8 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer {
       atb_speed::deepseekV2::DecoderLayerParam& param,
       const ModelArgs& args,
       const ParallelArgs& parallel_args,
-      bool is_prefill);
+      bool is_prefill,
+      bool is_prefixcache);
 
   void initialize_attention_parameters(
       atb_speed::deepseekV2::DecoderLayerParam& param,
@@ -312,9 +314,11 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer {
   int32_t num_speculative_tokens_ = 0;
 
   atb_speed::deepseekV2::DecoderLayerParam prefill_param_;
+  atb_speed::deepseekV2::DecoderLayerParam prefill_param_prefixcache_;
   atb_speed::deepseekV2::DecoderLayerParam decode_param_;
 
   atb_speed::Model::Node prefill_node_;
+  atb_speed::Model::Node prefill_node_prefixcache_;
   atb_speed::Model::Node decode_node_;
 
   atb::Tensor internal_tensor_;
diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h
index 25e3aaa66..91c02cca1 100755
--- a/xllm/core/runtime/forward_params.h
+++ b/xllm/core/runtime/forward_params.h
@@ -153,6 +153,7 @@ struct RawForwardInput {
   uint32_t q_max_seq_len;
   std::vector<int32_t> seq_lens;
   std::vector<int32_t> q_seq_lens;
+  std::vector<int32_t> kv_cache_tokens_nums;
   std::vector<int32_t> new_token_slot_ids;
   std::vector<std::vector<int32_t>> block_tables_vec;
   int32_t num_sequences;
diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp
index f9740d783..b0c8c81ea 100755
--- a/xllm/core/runtime/forward_shared_memory_manager.cpp
+++ b/xllm/core/runtime/forward_shared_memory_manager.cpp
@@ -123,6 +123,7 @@ INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) {
   total += get_vector_size(input.unique_token_lens_vec);
   total += get_vector_size(input.seq_lens);
   total += get_vector_size(input.q_seq_lens);
+  total += get_vector_size(input.kv_cache_tokens_nums);
   total += get_vector_size(input.new_token_slot_ids);
   total += get_vector_size(input.dp_global_token_nums);
   total += get_vector_size(input.embedding_ids);
@@ -570,6 +571,7 @@ INLINE void deserialize_raw_forward_input(
   read_vector(buffer, input.unique_token_lens_vec);
   read_vector(buffer, input.seq_lens);
   read_vector(buffer, input.q_seq_lens);
+  read_vector(buffer, input.kv_cache_tokens_nums);
   read_vector(buffer, input.new_token_slot_ids);
   read_vector(buffer, input.dp_global_token_nums);
   read_vector(buffer, input.embedding_ids);
@@ -629,6 +631,7 @@ INLINE void serialize_raw_forward_input(const RawForwardInput& input,
   write_vector(buffer, input.unique_token_lens_vec);
   write_vector(buffer, input.seq_lens);
   write_vector(buffer, input.q_seq_lens);
+  write_vector(buffer, input.kv_cache_tokens_nums);
   write_vector(buffer, input.new_token_slot_ids);
   write_vector(buffer, input.dp_global_token_nums);
   write_vector(buffer, input.embedding_ids);
@@ -877,6 +880,8 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
   input_params.new_cache_slots =
       torch::tensor(std::move(raw_input.new_token_slot_ids), tensor_options);
+  input_params.kv_cache_tokens_nums =
+      torch::tensor(std::move(raw_input.kv_cache_tokens_nums), tensor_options);
   input_params.decode_seq_range = decode_seq_range;
 
   util::pad_2d_vector(raw_input.block_tables_vec, 0);
   input_params.block_tables =
diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp
index 0eed88711..4bd6fb8dc 100755
--- a/xllm/core/runtime/llm_engine.cpp
+++ b/xllm/core/runtime/llm_engine.cpp
@@ -41,6 +41,8 @@ limitations under the License.
 namespace xllm {
 
+constexpr int32_t NZ_ALIGNMENT = 16;
+
 namespace {
 uint32_t determine_micro_batches_num(const std::vector<Batch>& batch) {
   bool not_all_in_decode =
@@ -257,7 +259,16 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
   const int64_t dtype_size = torch::scalarTypeToTypeMeta(dtype_).itemsize();
   int64_t slot_size = 0;
   if (FLAGS_enable_mla) {
-    slot_size = dtype_size * (args_.kv_lora_rank() + args_.qk_rope_head_dim());
+    if (FLAGS_enable_prefix_cache) {
+      slot_size =
+          dtype_size *
+          ((args_.kv_lora_rank() + NZ_ALIGNMENT - 1) / NZ_ALIGNMENT +
+           (args_.qk_rope_head_dim() + NZ_ALIGNMENT - 1) / NZ_ALIGNMENT);
+    } else {
+      slot_size =
+          dtype_size * (args_.kv_lora_rank() + args_.qk_rope_head_dim());
+    }
+
   } else {
     slot_size = 2 * dtype_size * head_dim_ * n_local_kv_heads_;
   }
@@ -299,10 +310,23 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
   std::vector<std::vector<int64_t>> kv_cache_shape;
   kv_cache_shape.reserve(2);
   if (FLAGS_enable_mla) {
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, 1, args_.kv_lora_rank()});
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
+    if (FLAGS_enable_prefix_cache) {
+      kv_cache_shape.emplace_back(
+          std::vector<int64_t>{kv_cache_cap.n_blocks,
+                               (args_.kv_lora_rank() + 15) / 16,
+                               block_size,
+                               16});
+      kv_cache_shape.emplace_back(
+          std::vector<int64_t>{kv_cache_cap.n_blocks,
+                               (args_.qk_rope_head_dim() + 15) / 16,
+                               block_size,
+                               16});
+    } else {
+      kv_cache_shape.emplace_back(std::vector<int64_t>{
+          kv_cache_cap.n_blocks, block_size, 1, args_.kv_lora_rank()});
+      kv_cache_shape.emplace_back(std::vector<int64_t>{
+          kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
+    }
   } else {
 #if defined(USE_NPU)
     kv_cache_shape.emplace_back(std::vector<int64_t>{
diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp
index 58f6f5571..3a9449d95 100644
--- a/xllm/core/runtime/params_utils.cpp
+++ b/xllm/core/runtime/params_utils.cpp
@@ -63,6 +63,9 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
   std::vector<int32_t> q_seq_lens =
       std::vector<int32_t>(pb_forward_input->q_seq_lens().begin(),
                            pb_forward_input->q_seq_lens().end());
+  std::vector<int32_t> kv_cache_tokens_nums =
+      std::vector<int32_t>(pb_forward_input->kv_cache_tokens_nums().begin(),
+                           pb_forward_input->kv_cache_tokens_nums().end());
   // aprint(q_seq_lens, "q_seq_lens", global_rank_);
   std::vector<std::vector<int32_t>> block_tables_vec;
   for (size_t i = 0; i < pb_forward_input->block_tables_vec().size(); ++i) {
@@ -212,6 +215,8 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
   input_params.q_max_seq_len = pb_forward_input->q_max_seq_len();
   input_params.kv_seq_lens = torch::tensor(seq_lens, tensor_options);
   input_params.q_seq_lens = torch::tensor(q_seq_lens, tensor_options);
+  input_params.kv_cache_tokens_nums =
+      torch::tensor(kv_cache_tokens_nums, tensor_options);
   input_params.kv_seq_lens_vec = std::move(seq_lens);
   input_params.q_seq_lens_vec = std::move(q_seq_lens);
 
@@ -238,6 +243,13 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
       torch::tensor(dst_block_indices, tensor_options);
   input_params.cum_sum = torch::tensor(cum_sum, tensor_options);
 
+  // input_params.ring_cur_seqlen_host = ring_cur_seqlen;
+  // input_params.ring_cache_seqlen_host = ring_cache_seqlen;
+  // input_params.history_compressed_kv = torch::tensor(history_compressed_kv, tensor_options);
+  // input_params.history_k_rope = torch::tensor(history_k_rope, tensor_options);
+  // input_params.ring_cur_seqlen = torch::tensor(ring_cur_seqlen, tensor_options);
+  // input_params.ring_cache_seqlen = torch::tensor(ring_cache_seqlen, tensor_options);
+
   if (pb_forward_input->embeds().size() > 0) {
     const int32_t rows = pb_forward_input->embeds().size();
     const int32_t cols = pb_forward_input->embeds()[0].vals().size();
@@ -397,8 +409,20 @@ void forward_input_to_proto(const RawForwardInput& inputs,
   pb_forward_input->set_max_seq_len(inputs.max_seq_len);
   pb_forward_input->set_q_max_seq_len(inputs.q_max_seq_len);
   ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_seq_lens(), inputs.seq_lens);
+  ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_kv_cache_tokens_nums(),
+                      inputs.kv_cache_tokens_nums);
   ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_q_seq_lens(),
                       inputs.q_seq_lens);
+
+  // ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_history_compressed_kv(),
+  //                     inputs.history_compressed_kv);
+  // ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_history_k_rope(),
+  //                     inputs.history_k_rope);
+  // ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_ring_cur_seqlen(),
+  //                     inputs.ring_cur_seqlen);
+  // ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_ring_cache_seqlen(),
+  //                     inputs.ring_cache_seqlen);
+
   ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_new_token_slot_ids(),
                       inputs.new_token_slot_ids);
   pb_forward_input->mutable_block_tables_vec()->Reserve(
diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp
index f1275ee89..cf66bb49f 100644
--- a/xllm/core/runtime/worker_impl.cpp
+++ b/xllm/core/runtime/worker_impl.cpp
@@ -48,6 +48,8 @@ limitations under the License.
 namespace xllm {
 
 constexpr uint64_t MBUF_SIZE = 128 * 1024 * 1024;
+constexpr int32_t FORMAT_ND = 2;
+constexpr int32_t FORMAT_NZ = 29;
 
 WorkerImpl::WorkerImpl(const ParallelArgs& parallel_args,
                        const torch::Device& device,
@@ -87,12 +89,14 @@ bool WorkerImpl::allocate_kv_cache(
   for (int64_t i = 0; i < num_layers; ++i) {
     torch::Tensor key_cache, value_cache;
 #if defined(USE_NPU)
+    int32_t npu_format_type =
+        FLAGS_enable_mla && FLAGS_enable_prefix_cache ? FORMAT_NZ : FORMAT_ND;
     key_cache = at_npu::native::npu_format_cast(
         torch::empty(kv_cache_shape[0], torch::dtype(dtype_).device(device_)),
-        2);
+        npu_format_type);
     value_cache = at_npu::native::npu_format_cast(
         torch::empty(kv_cache_shape[1], torch::dtype(dtype_).device(device_)),
-        2);
+        npu_format_type);
 #elif defined(USE_MLU)
     key_cache =
         torch::empty(kv_cache_shape[0], torch::dtype(dtype_).device(device_));
@@ -448,6 +452,10 @@ void WorkerImpl::prepare_work_before_execute(
         // expert_load_data_.fill_(0);
         fwd_inputs_on_device.input_params.expert_load_data = expert_load_data_;
       }
+      // DeepSeek MLA uses the prefix-cache path for chunked prefill.
+      if (FLAGS_enable_mla && input_params.batch_forward_type.is_chunked_prefill()) {
+        prepare_mla_prefixcache_inputs(input_params);
+      }
     }
 #endif
     processed_inputs.micro_inputs.push_back(std::move(fwd_inputs_on_device));
@@ -759,4 +767,32 @@ int64_t WorkerImpl::get_active_activation_memory() {
       .active_activation_memory;
 }
 
+void WorkerImpl::prepare_mla_prefixcache_inputs(
+    ModelInputParams& input_params) {
+  // Total number of already-cached (prefix) tokens across the batch.
+  int32_t sum_prefix = input_params.kv_cache_tokens_nums.sum().item<int32_t>();
+  // Device buffers that will hold the historical compressed KV and k_rope of
+  // the prefix tokens.
+  input_params.history_compressed_kv =
+      torch::empty({sum_prefix, context_.get_model_args().kv_lora_rank()},
+                   torch::TensorOptions().dtype(dtype_).pinned_memory(true))
+          .to(device_);
+
+  input_params.history_k_rope =
+      torch::empty({sum_prefix, context_.get_model_args().qk_rope_head_dim()},
+                   torch::TensorOptions().dtype(dtype_).pinned_memory(true))
+          .to(device_);
+
+  input_params.ring_cur_seqlen =
+      torch::stack({input_params.q_seq_lens, input_params.q_seq_lens})
+          .to(device_);
+
+  input_params.ring_cache_seqlen =
+      torch::stack({input_params.q_seq_lens,
+                    input_params.kv_cache_tokens_nums.to(device_)})
+          .to(device_);
+
+  torch::Tensor ring_cur_seqlen_host =
+      input_params.ring_cur_seqlen.cpu().contiguous();
+  torch::Tensor ring_cache_seqlen_host =
+      input_params.ring_cache_seqlen.cpu().contiguous();
+  input_params.ring_cur_seqlen_host =
+      std::vector<int32_t>(ring_cur_seqlen_host.data_ptr<int32_t>(),
+                           ring_cur_seqlen_host.data_ptr<int32_t>() +
+                               ring_cur_seqlen_host.numel());
+  input_params.ring_cache_seqlen_host =
+      std::vector<int32_t>(ring_cache_seqlen_host.data_ptr<int32_t>(),
+                           ring_cache_seqlen_host.data_ptr<int32_t>() +
+                               ring_cache_seqlen_host.numel());
+}
+
 }  // namespace xllm
diff --git a/xllm/core/runtime/worker_impl.h b/xllm/core/runtime/worker_impl.h
index 63b1560e0..41650d9de 100644
--- a/xllm/core/runtime/worker_impl.h
+++ b/xllm/core/runtime/worker_impl.h
@@ -186,6 +186,7 @@ class WorkerImpl {
  private:
   void update_last_step_output(const std::optional<ForwardOutput>& output);
+  void prepare_mla_prefixcache_inputs(ModelInputParams& input_params);
 
  protected:
   // runtime options
diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp
index 0182b6877..162476842 100644
--- a/xllm/core/util/utils.cpp
+++ b/xllm/core/util/utils.cpp
@@ -80,7 +80,7 @@ torch::ScalarType parse_dtype(const std::string& dtype_str,
   }
   if ((boost::iequals(dtype_str, "float") ||
        boost::iequals(dtype_str, "float32"))) {
-    return torch::kFloat16;
+    return torch::kFloat32;
   }
 
   if (dtype_str.empty() || boost::iequals(dtype_str, "auto")) {
@@ -89,6 +89,25 @@ torch::ScalarType parse_dtype(const std::string& dtype_str,
     return torch::kFloat16;
   }
   CHECK(false) << "Unsupported dtype: " << dtype_str << " on device "
               << device;
 }
 
+torch::ScalarType parse_dtype(const std::string& dtype_str) {
+  if (boost::iequals(dtype_str, "half") ||
+      boost::iequals(dtype_str, "float16")) {
+    return torch::kFloat16;
+  }
+  if (boost::iequals(dtype_str, "bfloat16")) {
+    return torch::kBFloat16;
+  }
+  if ((boost::iequals(dtype_str, "float") ||
+       boost::iequals(dtype_str, "float32"))) {
+    return torch::kFloat32;
+  }
+
+  if (dtype_str.empty() || boost::iequals(dtype_str, "auto")) {
+    return torch::kFloat16;
+  }
+  CHECK(false) << "Unsupported dtype: " << dtype_str;
+}
+
 std::optional<std::vector<uint32_t>> parse_batch_sizes(
     const std::string& batch_sizes_str) {
   if (batch_sizes_str.empty() || batch_sizes_str == "auto") {
diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h
index 51491972d..c24ab7073 100644
--- a/xllm/core/util/utils.h
+++ b/xllm/core/util/utils.h
@@ -46,6 +46,8 @@ void pad_2d_vector(std::vector<std::vector<T>>& vec, T pad_value) {
 torch::ScalarType parse_dtype(const std::string& dtype_str,
                               const torch::Device& device);
 
+torch::ScalarType parse_dtype(const std::string& dtype_str);
+
 std::optional<std::vector<uint32_t>> parse_batch_sizes(
     const std::string& batch_sizes_str);
 
diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h
index 010993a4b..6bebae450 100644
--- a/xllm/models/llm/deepseek_v2.h
+++ b/xllm/models/llm/deepseek_v2.h
@@ -184,10 +184,11 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
       auto sin_pos = cos_sin_chunks[1].contiguous();
 
       torch::Tensor attn_mask;
-      if (num_speculative_tokens_ == 0 ||
-          input_params[i].global_empty_kv_cache) {
+      if (input_params[0].batch_forward_type.is_chunked_prefill()) {
+        attn_mask = attn_mask_.get_attn_mask(512, dtype_, device_);
+      } else if (input_params[0].batch_forward_type.is_prefill()) {
         attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
-      } else {
+      } else if (num_speculative_tokens_ > 0) {
         attn_mask = attn_mask_.gen_free_mask(
             num_speculative_tokens_ + 1, dtype_, device_);
       }
diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto
index e67f84c0b..f07d76020 100644
--- a/xllm/proto/worker.proto
+++ b/xllm/proto/worker.proto
@@ -194,6 +194,7 @@ message ForwardInput {
   // beam search kernel
   repeated float acc_logprob_vec = 37;
   int32 batch_forward_type = 38;
+  repeated int32 kv_cache_tokens_nums = 39;
 }
 
 message BatchedForwardInputs {
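For reference, the NZ layout used above stores each MLA cache tensor with its feature dimension split into 16-element fragments, so shapes and per-token sizes are rounded up to multiples of 16. Below is a minimal standalone C++ sketch of that arithmetic; the helper names (`round_up_to`, `nz_cache_shape`) and the example numbers (kv_lora_rank = 512, qk_rope_head_dim = 64) are illustrative and not part of this patch.

```cpp
#include <cstdint>
#include <vector>

// Round x up to the next multiple of `alignment` (16 for NZ-formatted caches).
constexpr int64_t round_up_to(int64_t x, int64_t alignment) {
  return (x + alignment - 1) / alignment * alignment;
}

// Shape of one NZ-formatted MLA cache tensor: the feature dimension `dim` is
// split into ceil(dim / 16) fragments of width 16 per block of tokens.
std::vector<int64_t> nz_cache_shape(int64_t n_blocks,
                                    int64_t block_size,
                                    int64_t dim) {
  constexpr int64_t kNzAlignment = 16;
  return {n_blocks, (dim + kNzAlignment - 1) / kNzAlignment, block_size,
          kNzAlignment};
}

// Example: kv_lora_rank = 512 and qk_rope_head_dim = 64 give shapes
// {n_blocks, 32, block_size, 16} and {n_blocks, 4, block_size, 16};
// per token that is round_up_to(512, 16) + round_up_to(64, 16) = 576 elements.
```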