From 1f40ca8c8a5b30d1a6852b0453fdd02dbcf581f4 Mon Sep 17 00:00:00 2001 From: DragonFive Date: Tue, 23 Dec 2025 21:18:43 +0800 Subject: [PATCH] feat: add rec_type and onerec batch input builder. --- .../rec_completion_service_impl.cpp | 21 +- xllm/core/common/types.h | 5 + xllm/core/distributed_runtime/rec_master.cpp | 413 +++++++- xllm/core/distributed_runtime/rec_master.h | 41 +- xllm/core/framework/batch/CMakeLists.txt | 4 + xllm/core/framework/batch/batch.cpp | 54 +- xllm/core/framework/batch/batch.h | 9 +- xllm/core/framework/batch/batch_factory.cpp | 58 ++ xllm/core/framework/batch/batch_factory.h | 7 + .../batch/onerec_batch_input_builder.cpp | 957 ++++++++++++++++++ .../batch/onerec_batch_input_builder.h | 119 +++ .../batch/rec_batch_input_builder.cpp | 58 ++ .../framework/batch/rec_batch_input_builder.h | 53 + .../core/framework/model/model_input_params.h | 138 +++ xllm/core/framework/request/CMakeLists.txt | 1 + xllm/core/framework/request/rec_type.h | 28 + xllm/core/framework/request/request.cpp | 2 + xllm/core/framework/request/request_state.h | 5 + xllm/core/framework/request/sequence.cpp | 123 ++- xllm/core/framework/request/sequence.h | 57 ++ xllm/core/scheduler/fixed_steps_scheduler.cpp | 22 +- 21 files changed, 2074 insertions(+), 101 deletions(-) create mode 100644 xllm/core/framework/batch/onerec_batch_input_builder.cpp create mode 100644 xllm/core/framework/batch/onerec_batch_input_builder.h create mode 100644 xllm/core/framework/batch/rec_batch_input_builder.cpp create mode 100644 xllm/core/framework/batch/rec_batch_input_builder.h create mode 100644 xllm/core/framework/request/rec_type.h diff --git a/xllm/api_service/rec_completion_service_impl.cpp b/xllm/api_service/rec_completion_service_impl.cpp index 9ffc4261b..f09ed0367 100644 --- a/xllm/api_service/rec_completion_service_impl.cpp +++ b/xllm/api_service/rec_completion_service_impl.cpp @@ -28,9 +28,7 @@ limitations under the License. 
#include "completion.pb.h" #include "core/distributed_runtime/llm_master.h" #include "core/distributed_runtime/rec_master.h" -#include "core/framework/request/mm_data.h" #include "core/framework/request/request_output.h" -#include "core/util/utils.h" #define likely(x) __builtin_expect(!!(x), 1) #define unlikely(x) __builtin_expect(!!(x), 0) @@ -167,18 +165,15 @@ void RecCompletionServiceImpl::process_async_impl( } const auto& rpc_request_ref = call->request(); - std::optional mm_data = std::nullopt; + std::optional> input_tensors = + std::nullopt; if (rpc_request_ref.input_tensors_size()) { - // HISTOGRAM_OBSERVE(rec_input_first_dim, - // rpc_request_ref.input_tensors(0).shape(0)); - - MMDict mm_dict; + std::vector tensors; + tensors.reserve(rpc_request_ref.input_tensors_size()); for (int i = 0; i < rpc_request_ref.input_tensors_size(); ++i) { - const auto& tensor = rpc_request_ref.input_tensors(i); - mm_dict[tensor.name()] = - xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16); + tensors.push_back(rpc_request_ref.input_tensors(i)); } - mm_data = std::move(MMData(MMType::EMBEDDING, mm_dict)); + input_tensors = std::move(tensors); } // schedule the request @@ -187,7 +182,7 @@ void RecCompletionServiceImpl::process_async_impl( master_->handle_request( std::move(rpc_request_ref.prompt()), std::move(prompt_tokens), - std::move(mm_data), + std::move(input_tensors), std::move(request_params), [call, model, @@ -219,4 +214,4 @@ void RecCompletionServiceImpl::process_async_impl( }); } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index 4f0dbb14b..8712e7a82 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -292,4 +292,9 @@ struct EplbInfo { inline constexpr int REC_TOKEN_SIZE = 3; using RecTokenTriple = std::array; + +inline constexpr const char* LLM_REC_INPUT_TOKENS = "llm_rec_input_tokens"; +inline constexpr const char* LLM_REC_INPUT_INDICES = "llm_rec_input_indices"; +inline constexpr const char* LLM_REC_INPUT_EMBEDDING = + "llm_rec_input_embedding"; } // namespace xllm diff --git a/xllm/core/distributed_runtime/rec_master.cpp b/xllm/core/distributed_runtime/rec_master.cpp index 0bb10a901..9f53ed055 100644 --- a/xllm/core/distributed_runtime/rec_master.cpp +++ b/xllm/core/distributed_runtime/rec_master.cpp @@ -19,9 +19,14 @@ limitations under the License. #include #include #include +#include + +#include #include "common/macros.h" #include "common/metrics.h" +#include "common/types.h" +#include "framework/request/mm_data.h" #include "models/model_registry.h" #include "rec_engine.h" #include "runtime/xservice_client.h" @@ -32,6 +37,182 @@ limitations under the License. 
namespace xllm { +namespace { + +constexpr int32_t kDefaultPlaceholderToken = 20152019; + +RecType get_rec_type(const ModelArgs& model_args) { + const auto& model_type = model_args.model_type(); + if (model_type == "onerec") { + return RecType::kOneRec; + } + if (model_type == "qwen2" || model_type == "qwen3") { + return RecType::kLlmRec; + } + return RecType::kNone; +} + +bool process_onerec_inputs( + const std::optional>& prompt_tokens, + const std::optional>& input_tensors, + std::vector* local_prompt_tokens, + MMData* processed_mm_data, + OutputCallback callback) { + if (prompt_tokens.has_value() && input_tensors.has_value()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "prompt_tokens and input_tensors cannot both be set"); + return false; + } + + if (prompt_tokens.has_value()) { + local_prompt_tokens->assign(prompt_tokens.value().begin(), + prompt_tokens.value().end()); + } + + if (input_tensors.has_value()) { + MMDict mm_dict; + mm_dict.reserve(input_tensors->size()); + for (const auto& tensor : input_tensors.value()) { + mm_dict[tensor.name()] = + util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16); + } + *processed_mm_data = MMData(MMType::EMBEDDING, mm_dict); + } + + if (local_prompt_tokens->empty() && !processed_mm_data->valid()) { + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "Rec model requires prompt_tokens or input_tensors to be provided"); + return false; + } + + return true; +} + +bool process_llmrec_raw_inputs( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + const ModelArgs& model_args, + std::vector* local_prompt_tokens, + MMData* processed_mm_data, + OutputCallback callback) { + std::vector local_input_tokens; + std::vector local_input_indices; + torch::Tensor input_tokens_tensor; + torch::Tensor input_indices_tensor; + torch::Tensor input_embedding_tensor; + int64_t embedding_rows = 0; + + if (input_tokens.has_value()) { + const auto& tokens = input_tokens.value(); + local_input_tokens.reserve(tokens.size()); + for (const auto token : tokens) { + local_input_tokens.push_back(static_cast(token)); + } + if (!local_input_tokens.empty()) { + input_tokens_tensor = + torch::from_blob(local_input_tokens.data(), + {static_cast(local_input_tokens.size())}, + torch::dtype(torch::kInt32).device(torch::kCPU)) + .clone(); + processed_mm_data->add( + MMType::EMBEDDING, LLM_REC_INPUT_TOKENS, input_tokens_tensor); + local_prompt_tokens->assign(local_input_tokens.begin(), + local_input_tokens.end()); + } + } + + if (input_indices.has_value()) { + if (!input_tokens.has_value()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input indices require input tokens"); + return false; + } + const auto& indices = input_indices.value(); + local_input_indices.reserve(indices.size()); + for (const auto index : indices) { + local_input_indices.push_back(static_cast(index)); + } + if (local_input_indices.size() != local_input_tokens.size()) { + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "LLMRec input indices size does not match input tokens"); + return false; + } + if (!local_input_indices.empty()) { + input_indices_tensor = + torch::from_blob(local_input_indices.data(), + {static_cast(local_input_indices.size())}, + torch::dtype(torch::kInt32).device(torch::kCPU)) + .clone(); + processed_mm_data->add( + MMType::EMBEDDING, LLM_REC_INPUT_INDICES, input_indices_tensor); + } + } + + if (input_embedding.has_value()) { + const auto& embedding_vec = input_embedding.value(); + if 
(embedding_vec.empty()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input embedding is empty"); + return false; + } + const int64_t rows = static_cast(embedding_vec.size()); + const int64_t cols = static_cast(embedding_vec[0].size()); + if (cols != model_args.hidden_size()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input embedding has invalid hidden size"); + return false; + } + + std::vector flat_data; + flat_data.reserve(static_cast(rows * cols)); + for (const auto& row : embedding_vec) { + flat_data.insert(flat_data.end(), row.begin(), row.end()); + } + input_embedding_tensor = + torch::from_blob(flat_data.data(), + {rows, cols}, + torch::dtype(torch::kFloat32).device(torch::kCPU)) + .clone(); + processed_mm_data->add( + MMType::EMBEDDING, LLM_REC_INPUT_EMBEDDING, input_embedding_tensor); + embedding_rows = rows; + local_prompt_tokens->insert(local_prompt_tokens->end(), + static_cast(embedding_rows), + kDefaultPlaceholderToken); + } + + if (!local_input_indices.empty()) { + const int64_t total_size = + static_cast(local_input_tokens.size()) + embedding_rows; + std::unordered_set seen; + seen.reserve(local_input_indices.size()); + for (const auto index : local_input_indices) { + if (index < 0 || index >= total_size) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input indices contain invalid values"); + return false; + } + if (!seen.insert(index).second) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input indices contain duplicate values"); + return false; + } + } + } + + if (local_prompt_tokens->empty()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is empty"); + return false; + } + + return true; +} + +} // namespace + RecMaster::RecMaster(const Options& options) : Master(options, EngineType::REC) { // Initialize with Rec engine type @@ -39,6 +220,10 @@ RecMaster::RecMaster(const Options& options) CHECK(engine_->init()); model_args_ = engine_->model_args(); + rec_type_ = get_rec_type(model_args_); + if (rec_type_ == RecType::kNone) { + LOG(ERROR) << "Unsupported rec model_type: " << model_args_.model_type(); + } bool enable_decode_response_to_service = false; if (options_.enable_service_routing()) { @@ -71,7 +256,6 @@ RecMaster::RecMaster(const Options& options) .enable_decode_response_to_service(enable_decode_response_to_service); scheduler_ = create_fixed_steps_scheduler(engine_.get(), scheduler_options); - // OmniRec model does not have a tokenizer chat_template_ = nullptr; tokenizer_ = nullptr; threadpool_ = @@ -108,39 +292,83 @@ RecMaster::~RecMaster() { void RecMaster::handle_request(std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, + std::optional> + input_tensors, RequestParams sp, OutputCallback callback) { - // add one pending request + if (rec_type_ != RecType::kOneRec) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "OneRec should use onerec input interface"); + return; + } + schedule_request( + std::move(sp), + std::move(callback), + [this, + prompt = std::move(prompt), + prompt_tokens = std::move(prompt_tokens), + input_tensors = std::move(input_tensors)]( + const RequestParams& params, + OutputCallback cb) mutable { + return generate_request(std::move(prompt), + std::move(prompt_tokens), + std::move(input_tensors), + params, + std::move(cb)); + }); +} + +void RecMaster::handle_request( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + RequestParams sp, + OutputCallback callback) { + if (rec_type_ 
!= RecType::kLlmRec) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec should use raw input interface"); + return; + } + schedule_request( + std::move(sp), + std::move(callback), + [this, + input_tokens = std::move(input_tokens), + input_indices = std::move(input_indices), + input_embedding = std::move(input_embedding)]( + const RequestParams& params, + OutputCallback cb) mutable { + return generate_request(std::move(input_tokens), + std::move(input_indices), + std::move(input_embedding), + params, + std::move(cb)); + }); +} + +void RecMaster::schedule_request(RequestParams sp, + OutputCallback callback, + RequestBuilder build_request) { scheduler_->incr_pending_requests(1); auto cb = [callback = std::move(callback), scheduler = scheduler_.get()](const RequestOutput& output) { output.log_request_status(); return callback(output); }; - // add into the queue threadpool_->schedule([this, - prompt = std::move(prompt), - prompt_tokens = std::move(prompt_tokens), - mm_data = std::move(mm_data), sp = std::move(sp), - callback = std::move(cb)]() mutable { + callback = std::move(cb), + build_request = std::move(build_request)]() mutable { AUTO_COUNTER(request_handling_latency_seconds_completion); - // remove the pending request after scheduling SCOPE_GUARD([this] { scheduler_->decr_pending_requests(); }); Timer timer; - // verify the prompt if (!sp.verify_params(callback)) { return; } - auto request = generate_request(std::move(prompt), - std::move(prompt_tokens), - std::move(mm_data), - sp, - callback); + auto request = build_request(sp, std::move(callback)); if (!request) { return; } @@ -155,40 +383,97 @@ void RecMaster::handle_request(std::string prompt, std::shared_ptr RecMaster::generate_request( std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, - RequestParams sp, + std::optional> input_tensors, + const RequestParams& sp, OutputCallback callback) { // For Rec model, prompt is expected to be empty and prompt_tokens should // contain the actual data Skip prompt empty check as mentioned in // requirements - Timer timer; - std::vector local_prompt_tokens; - - if (prompt_tokens.has_value()) { - local_prompt_tokens = std::move(prompt_tokens.value()); - LOG(INFO) - << "[Rec DEBUG] generate_request - received prompt_tokens.size(): " - << local_prompt_tokens.size() - << ", prompt.length(): " << prompt.length(); - } else if (!mm_data.has_value()) { - // sparse LLM - LOG(ERROR) << "Rec model requires prompt_tokens/embedding to be provided"; + if (rec_type_ == RecType::kNone) { + LOG(ERROR) << "Unsupported rec model_type: " << model_args_.model_type(); CALLBACK_WITH_ERROR( StatusCode::INVALID_ARGUMENT, - "Rec model requires prompt_tokens/embedding to be provided"); + std::string("Unsupported rec model_type: ") + model_args_.model_type()); + return nullptr; + } + + Timer timer; + std::vector local_prompt_tokens; + MMData processed_mm_data; + bool processed_ok = false; + + if (rec_type_ == RecType::kOneRec) { + processed_ok = process_onerec_inputs(prompt_tokens, + input_tensors, + &local_prompt_tokens, + &processed_mm_data, + callback); + } + + if (!processed_ok) { return nullptr; } COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + return build_request_common(std::move(prompt), + std::move(local_prompt_tokens), + std::move(processed_mm_data), + sp, + callback, + rec_type_ == RecType::kLlmRec); +} + +std::shared_ptr RecMaster::generate_request( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + const 
RequestParams& sp, + OutputCallback callback) { + if (rec_type_ != RecType::kLlmRec) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec inputs require rec_type kLlmRec"); + return nullptr; + } + + Timer timer; + std::vector local_prompt_tokens; + MMData processed_mm_data; + if (!process_llmrec_raw_inputs(std::move(input_tokens), + std::move(input_indices), + std::move(input_embedding), + model_args_, + &local_prompt_tokens, + &processed_mm_data, + callback)) { + return nullptr; + } + + COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + + return build_request_common(std::string(""), + std::move(local_prompt_tokens), + std::move(processed_mm_data), + sp, + callback, + true); +} + +std::shared_ptr RecMaster::build_request_common( + std::string prompt, + std::vector prompt_tokens, + MMData mm_data, + const RequestParams& sp, + OutputCallback callback, + bool build_stop_checker) { int32_t max_context_len = model_args_.max_position_embeddings(); if (!options_.enable_chunked_prefill()) { max_context_len = std::min(max_context_len, options_.max_tokens_per_batch()); } - if (local_prompt_tokens.size() >= max_context_len) { - LOG(ERROR) << "Prompt is too long: " << local_prompt_tokens.size(); + if (prompt_tokens.size() >= max_context_len) { + LOG(ERROR) << "Prompt is too long: " << prompt_tokens.size(); CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is too long"); return nullptr; } @@ -199,9 +484,7 @@ std::shared_ptr RecMaster::generate_request( max_tokens = kDefaultMaxTokens; } - // allocate enough capacity for prompt tokens, max tokens, and speculative - // tokens - size_t capacity = local_prompt_tokens.size() + max_tokens + + size_t capacity = prompt_tokens.size() + max_tokens + options_.num_speculative_tokens() + /*bonus_token*/ 1; if (options_.enable_schedule_overlap()) { capacity += options_.num_speculative_tokens() + 1; @@ -220,30 +503,55 @@ std::shared_ptr RecMaster::generate_request( sampling_param.is_embeddings = sp.is_embeddings; sampling_param.beam_width = sp.beam_width; if (best_of > sp.n) { - // enable logprobs for best_of to generate sequence logprob sampling_param.logprobs = true; } - // sampling_param.do_sample = sp.do_sample; bool stream = sp.streaming; - // results cannot be streamed when best_of != n if (best_of != sp.n) { stream = false; } - // std::unordered_set stop_tokens; - // std::vector> stop_sequences; - // StoppingChecker stopping_checker( - // max_tokens, - // max_context_len - options_.num_speculative_tokens(), - // , - // model_args_.eos_token_id(), - // sp.ignore_eos, - // std::move(stop_tokens), - // std::move(stop_sequences)); + StoppingChecker stopping_checker; + if (build_stop_checker) { + std::unordered_set stop_tokens; + if (sp.stop_token_ids.has_value()) { + const auto& stop_token_ids = sp.stop_token_ids.value(); + stop_tokens.insert(stop_token_ids.begin(), stop_token_ids.end()); + } else { + stop_tokens = model_args_.stop_token_ids(); + } + + std::vector> stop_sequences; + if (sp.stop.has_value()) { + if (!tokenizer_) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Tokenizer is required for stop sequences"); + return nullptr; + } + for (const auto& s : sp.stop.value()) { + std::vector tmp_tokens; + if (!tokenizer_->encode(s, &tmp_tokens)) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Failed to encode stop sequence"); + LOG(ERROR) << "Failed to encode stop sequence: " << s; + return nullptr; + } + stop_sequences.push_back(std::move(tmp_tokens)); + } + } + + stopping_checker = StoppingChecker( + 
max_tokens, + max_context_len - options_.num_speculative_tokens(), + model_args_.eos_token_id(), + sp.ignore_eos, + std::move(stop_tokens), + std::move(stop_sequences)); + } + RequestState req_state(std::move(prompt), - std::move(local_prompt_tokens), - mm_data.value_or(MMData{}), + std::move(prompt_tokens), + std::move(mm_data), std::move(sampling_param), std::move(stopping_checker), capacity, @@ -257,9 +565,8 @@ std::shared_ptr RecMaster::generate_request( callback, nullptr, sp.decode_address); - // TODO. add following when next pr (add is_rec_model and bos_token_id to - // RequestState). req_state.is_rec_model = true; req_state.bos_token_id = - // model_args_.bos_token_id(); + req_state.rec_type = rec_type_; + req_state.bos_token_id = model_args_.bos_token_id(); auto request = std::make_shared(sp.request_id, sp.x_request_id, sp.x_request_time, diff --git a/xllm/core/distributed_runtime/rec_master.h b/xllm/core/distributed_runtime/rec_master.h index 0ed5b76d3..a8eeda55e 100644 --- a/xllm/core/distributed_runtime/rec_master.h +++ b/xllm/core/distributed_runtime/rec_master.h @@ -16,12 +16,15 @@ limitations under the License. #pragma once #include +#include #include #include "framework/chat_template/jinja_chat_template.h" #include "framework/model/model_args.h" +#include "framework/request/rec_type.h" #include "master.h" #include "rec_engine.h" +#include "rec.pb.h" #include "scheduler/continuous_scheduler.h" #include "scheduler/fixed_steps_scheduler.h" #include "util/threadpool.h" @@ -37,24 +40,56 @@ class RecMaster : public Master { // completion/encode void handle_request(std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, + std::optional> + input_tensors, RequestParams sp, OutputCallback callback); + void handle_request( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + RequestParams sp, + OutputCallback callback); + // start the handling loop void run() override; private: + using RequestBuilder = + std::function(const RequestParams&, + OutputCallback)>; + + void schedule_request(RequestParams sp, + OutputCallback callback, + RequestBuilder build_request); + std::shared_ptr generate_request( std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, - RequestParams sp, + std::optional> input_tensors, + const RequestParams& sp, OutputCallback callback); + std::shared_ptr generate_request( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + const RequestParams& sp, + OutputCallback callback); + + std::shared_ptr build_request_common( + std::string prompt, + std::vector prompt_tokens, + MMData mm_data, + const RequestParams& sp, + OutputCallback callback, + bool build_stop_checker); + std::unique_ptr scheduler_; // model args ModelArgs model_args_; + RecType rec_type_ = RecType::kNone; std::unique_ptr threadpool_; std::unique_ptr tokenizer_; // chat template instance diff --git a/xllm/core/framework/batch/CMakeLists.txt b/xllm/core/framework/batch/CMakeLists.txt index 94d202400..c0be24ce7 100644 --- a/xllm/core/framework/batch/CMakeLists.txt +++ b/xllm/core/framework/batch/CMakeLists.txt @@ -10,12 +10,16 @@ cc_library( batch.h batch_factory.h batch_input_builder.h + rec_batch_input_builder.h + onerec_batch_input_builder.h mposition.h SRCS dit_batch.cpp batch.cpp batch_factory.cpp batch_input_builder.cpp + rec_batch_input_builder.cpp + onerec_batch_input_builder.cpp mposition.cpp beam_search.h DEPS diff --git a/xllm/core/framework/batch/batch.cpp 
b/xllm/core/framework/batch/batch.cpp index 5699cc818..d531b4f6e 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -29,6 +29,7 @@ limitations under the License. #include "framework/model/model_input_params.h" #include "framework/request/sequence.h" #include "framework/sampling/sampling_params.h" +#include "rec_batch_input_builder.h" #include "runtime/params_utils.h" #include "util/slice.h" #include "util/tensor_helper.h" @@ -96,6 +97,10 @@ void Batch::add(const std::vector& sequences) { ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size, const ModelArgs& args) { + if (sequences_.empty() && !sequence_groups_.empty()) { + return prepare_rec_forward_input( + num_decoding_tokens, min_decoding_batch_size, args); + } BatchInputBuilder builder(sequences_, allowed_max_tokens_, input_embeddings_vec_, @@ -108,6 +113,43 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, min_decoding_batch_size); } +ForwardInput Batch::prepare_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size, + const ModelArgs& args, + ThreadPool* thread_pool) { + RecType rec_type = RecType::kNone; + if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) { + rec_type = sequence_groups_[0]->sequences()[0]->rec_type(); + } + + auto builder = RecBatchInputBuilder::create(rec_type, + sequence_groups_, + allowed_max_tokens_, + input_embeddings_vec_, + mm_data_vec_, + swap_block_transfer_infos_, + batch_id_, + &args, + thread_pool); + return builder->build_rec_forward_input(num_decoding_tokens, + min_decoding_batch_size); +} + +std::vector Batch::get_sequences() const { + if (!sequences_.empty()) { + return sequences_; + } + + std::vector result; + for (const auto* seq_group : sequence_groups_) { + const auto& sequences = seq_group->sequences(); + for (const auto& seq_ptr : sequences) { + result.push_back(seq_ptr.get()); + } + } + return result; +} + void Batch::dp_balance_shuffle_seqs() { // this shuffle operation is mainly used for npu with 24 cores // and specific mla op implementation @@ -217,7 +259,8 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output, // this means all sequences are in prefill stage status. const int64_t num_seqs = raw_output.outputs.size(); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto sequences = get_sequences(); + for (auto* seq : sequences) { if (seq->finished()) { output_idx++; continue; @@ -264,7 +307,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output, if (sample_output.embeddings.defined()) { const int64_t num_seqs = sample_output.embeddings.size(0); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto sequences = get_sequences(); + for (auto* seq : sequences) { CHECK_LT(output_idx, num_seqs); auto cur_seq_embed = safe_to(sample_output.embeddings[output_idx++], torch::kFloat32); @@ -277,7 +321,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output, // this means all sequences are in prefill stage status. 
const int64_t num_seqs = sample_output.next_tokens.size(0); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto sequences = get_sequences(); + for (auto* seq : sequences) { if (seq->finished()) { output_idx++; continue; @@ -352,7 +397,8 @@ void Batch::process_embedding_output(const torch::Tensor& output_embedding) { Token token(0); if (output_embedding.defined()) { int32_t slice_img_index = 0; - for (auto* seq : sequences_) { // TODO + const auto sequences = get_sequences(); + for (auto* seq : sequences) { const auto& mm_data = seq->get_mm_data(); auto pixel_values = mm_data.get_tensor_vec("pixel_values"); diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index 31c409f82..0800231c3 100755 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -75,7 +75,7 @@ class Batch { // get the number of sequences in the batch size_t size() const { return sequences_.size(); } - bool empty() const { return sequences_.empty(); } + bool empty() const { return sequences_.empty() && sequence_groups_.empty(); } Sequence* operator[](size_t i) { return sequences_[i]; } @@ -84,6 +84,11 @@ class Batch { uint32_t min_decoding_bach_size, const ModelArgs& args); + ForwardInput prepare_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size, + const ModelArgs& args, + ThreadPool* thread_pool = nullptr); + // Convert Batch to pb type, which will be pass to remote worker. RawForwardInput prepare_forward_input(const ModelArgs& args, ThreadPool* thread_pool); @@ -138,6 +143,8 @@ class Batch { void dp_balance_shuffle_seqs(); + std::vector get_sequences() const; + std::vector sequences_; std::vector sequence_groups_; std::vector* swap_block_transfer_infos_ = nullptr; diff --git a/xllm/core/framework/batch/batch_factory.cpp b/xllm/core/framework/batch/batch_factory.cpp index da0c9b55b..7e0e78973 100644 --- a/xllm/core/framework/batch/batch_factory.cpp +++ b/xllm/core/framework/batch/batch_factory.cpp @@ -92,4 +92,62 @@ std::vector BatchFactory::create_batches( return batches; } +std::vector BatchFactory::create_rec_batches( + const std::vector>& running_requests, + const std::vector& running_sequences, + const std::vector& running_sequences_budgets, + std::vector>* swap_block_transfer_infos) { + size_t num_prompt_tokens = 0; + size_t num_generated_tokens = 0; + std::vector batches(dp_size_); + for (size_t i = 0; i < running_sequences.size(); ++i) { + auto* sequence = running_sequences[i]; + const size_t token_budget = running_sequences_budgets[i]; + + const size_t remaining_prompt_tokens = + sequence->num_prompt_tokens() > + sequence->kv_state().kv_cache_tokens_num() + ? 
sequence->num_prompt_tokens() - + sequence->kv_state().kv_cache_tokens_num() + : 0; + const size_t prompt_tokens = + std::min(remaining_prompt_tokens, token_budget); + const size_t generated_tokens = token_budget - prompt_tokens; + num_prompt_tokens += prompt_tokens; + num_generated_tokens += generated_tokens; + + batches[sequence->dp_rank()].set_batch_id(); + } + + for (const auto& request : running_requests) { + auto seq_group = request->sequence_group(); + int32_t dp_rank = seq_group->dp_rank(); + batches[dp_rank].add(seq_group); + } + + for (int i = 0; i < dp_size_; i++) { + if (!batches[i].empty()) { + if (swap_block_transfer_infos != nullptr && + swap_block_transfer_infos->size() == dp_size_) { + batches[i].set_swap_block_transfer_infos( + &(swap_block_transfer_infos->at(i))); + } + } + } + + COUNTER_ADD(num_processing_tokens_total_prompt, num_prompt_tokens); + COUNTER_ADD(num_processing_tokens_total_generated, num_generated_tokens); + + if (running_sequences.size() > 0) { + HISTOGRAM_OBSERVE( + num_prompt_tokens_per_request, + static_cast(num_prompt_tokens / running_sequences.size())); + HISTOGRAM_OBSERVE( + num_generated_tokens_per_request, + static_cast(num_generated_tokens / running_sequences.size())); + } + + return batches; +} + } // namespace xllm diff --git a/xllm/core/framework/batch/batch_factory.h b/xllm/core/framework/batch/batch_factory.h index bd7d7a084..2db720bd8 100644 --- a/xllm/core/framework/batch/batch_factory.h +++ b/xllm/core/framework/batch/batch_factory.h @@ -35,6 +35,13 @@ class BatchFactory { std::vector>* swap_block_transfer_infos = nullptr); + std::vector create_rec_batches( + const std::vector>& running_requests, + const std::vector& running_sequences, + const std::vector& running_sequences_budgets, + std::vector>* swap_block_transfer_infos = + nullptr); + private: BatchFactory(int32_t dp_size) : dp_size_(dp_size) {} ~BatchFactory() = default; diff --git a/xllm/core/framework/batch/onerec_batch_input_builder.cpp b/xllm/core/framework/batch/onerec_batch_input_builder.cpp new file mode 100644 index 000000000..782595a60 --- /dev/null +++ b/xllm/core/framework/batch/onerec_batch_input_builder.cpp @@ -0,0 +1,957 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "onerec_batch_input_builder.h" + +#include +#include +#include +#include +#include +#include + +#include "framework/model/model_input_params.h" +#include "framework/request/sequence.h" +#include "framework/sampling/sampling_params.h" +#include "util/tensor_helper.h" +#include "util/utils.h" + +namespace xllm { + +// Use Meyers' Singleton pattern to avoid static initialization order fiasco +// This ensures the cache is initialized on first use, after all dependencies +// (like PyTorch runtime) are properly initialized. 
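// A minimal standalone sketch of the Meyers' singleton idiom described in the
// comment above: a function-local static is constructed on first use, and
// C++11 guarantees that construction is thread-safe, sidestepping the static
// initialization order fiasco for members (e.g. torch::Tensor) that need the
// runtime to be up first. ExpensiveCache and get_cache are hypothetical names
// used only for illustration; they are not part of this patch.
#include <cstdint>
#include <vector>

struct ExpensiveCache {
  std::vector<int32_t> scratch;  // stands in for buffers that need late init
};

ExpensiveCache& get_cache() {
  static ExpensiveCache cache;  // constructed lazily, exactly once, thread-safe
  return cache;
}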
+OneRecBatchInputBuilder::HighPerformanceCache& +OneRecBatchInputBuilder::get_perf_cache() { + static HighPerformanceCache cache; + cache.ensure_tensors_initialized(); + return cache; +} + +OneRecBatchInputBuilder::OneRecBatchInputBuilder( + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + const uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool) + : sequence_groups_(sequence_groups), + allowed_max_tokens_(allowed_max_tokens), + input_embeddings_vec_(input_embeddings_vec), + mm_data_vec_(mm_data_vec), + swap_block_transfer_infos_(swap_block_transfer_infos), + batch_id_(batch_id), + args_(args), + thread_pool_(thread_pool) { + // Get references to function-local statics (safe initialization) + auto& perf_cache = get_perf_cache(); + perf_cache.memory_pool.reset(); +} + +ForwardInput OneRecBatchInputBuilder::build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) { + // Get reference to function-local static cache (safe initialization) + auto& perf_cache = get_perf_cache(); + + // ========== Global constant cache ========== + // Note: FIXED_POSITIONS is a simple vector, safe for static initialization + static const std::vector FIXED_POSITIONS = {0}; + // Note: FIXED_ENCODER_POSITIONS is now obtained from perf_cache to avoid + // static initialization order issues with torch::Tensor + + // ========== Fast sequence information extraction ========== + const int32_t num_sequences = + !sequence_groups_.empty() + ? std::accumulate(sequence_groups_.begin(), + sequence_groups_.end(), + 0, + [](int sum, const auto& group) { + return sum + group->sequences().size(); + }) + : 0; + const int32_t THREADPOOL_THRESHOLD = 16; + if (num_sequences == 0) { + return ForwardInput{}; + } + + // Get basic information of first sequence - optimize pointer access + Sequence* first_sequence = nullptr; + if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) { + first_sequence = sequence_groups_[0]->sequences()[0].get(); + } + + if (!first_sequence) { + return ForwardInput{}; + } + + const uint32_t seq_len = first_sequence->num_tokens(); + const uint32_t num_decoder_embeddings = + first_sequence->num_decoder_embeddings(); + const uint32_t n_prompt_tokens = first_sequence->num_prompt_tokens(); + const bool is_first_prefill = (first_sequence->num_generated_tokens() == 0); + // const uint64_t model_version = first_sequence->get_model_version(); + + // ========== High-performance encoder tokens construction ========== + auto build_encoder_tokens_optimized = [&]() + -> const std::vector& { + auto& cache_data = perf_cache.cache_data; + + // encoder doesn't use cache key, because encoder doesn't use encoder_tokens + // in non-first prefill scenarios, only uses encoder_seq_len + if (!is_first_prefill) { + return cache_data.encoder_tokens; + } + + // Optimization: Use SIMD-friendly memory access patterns + cache_data.encoder_tokens.clear(); + cache_data.encoder_seq_lens.clear(); + + // Optimization for scenarios where sequences have different lengths across + // sequence groups Pre-calculate total token count to avoid multiple memory + // reallocations + int32_t total_tokens = 0; + for (const auto& group_ptr : sequence_groups_) { + if (!group_ptr->sequences().empty()) { + // Sequences within group have same length, only need to get first + // sequence's length + const int32_t group_encoder_seq_len = + 
group_ptr->sequences()[0]->encoder_tokens().size(); + total_tokens += group_encoder_seq_len * group_ptr->sequences().size(); + } + } + + cache_data.encoder_tokens.reserve(total_tokens); + cache_data.encoder_seq_lens.resize(num_sequences); + cache_data.encoder_sparse_embeddings.clear(); + cache_data.encoder_sparse_embeddings.reserve(num_sequences); + cache_data.decoder_context_embeddings.clear(); + cache_data.decoder_context_embeddings.reserve(num_sequences); + + // Process by groups in batch + int32_t global_seq_idx = 0; + for (const auto& group_ptr : sequence_groups_) { + const auto& group = *group_ptr; + const int32_t group_size = group.sequences().size(); + + if (group_size == 0) continue; + + const int32_t group_encoder_seq_len = + group.sequences()[0]->encoder_seq_len(); + + // Batch set same values + std::fill_n(&cache_data.encoder_seq_lens[global_seq_idx], + group_size, + group_encoder_seq_len); + + // Batch copy tokens by sequence and collect sparse_embedding + for (const auto& sequence : group.sequences()) { + const auto& encoder_tokens = sequence->encoder_tokens(); + const int32_t* src_ptr = encoder_tokens.data(); + const int32_t group_encoder_seq_len = encoder_tokens.size(); + + // Use efficient batch insertion + if (group_encoder_seq_len > 0) { + cache_data.encoder_tokens.insert(cache_data.encoder_tokens.end(), + src_ptr, + src_ptr + group_encoder_seq_len); + } + // Collect sparse_embedding + auto mm_data = sequence->get_mm_data(); + auto sparse_embedding_optional = + mm_data.get(Sequence::ENCODER_SPARSE_EMBEDDING_NAME); + if (sparse_embedding_optional.has_value()) { + cache_data.encoder_sparse_embeddings.push_back( + sparse_embedding_optional.value()); + } + + auto decoder_context_embedding_optional = mm_data.get( + Sequence::DECODER_CONTEXT_EMBEDDING_NAME); + if (decoder_context_embedding_optional.has_value()) { + cache_data.decoder_context_embeddings.push_back( + decoder_context_embedding_optional.value()); + } + } + + global_seq_idx += group_size; + } + + return cache_data.encoder_tokens; + }; + + // ========== High-performance decoder data construction ========== + auto build_decoder_data_optimized = [&]() { + // Pre-allocate all containers to avoid dynamic expansion + const size_t total_tokens = num_sequences * seq_len; + std::vector flatten_tokens_vec; + flatten_tokens_vec.reserve(total_tokens); + std::vector sampling_params; + sampling_params.reserve(num_sequences); + std::vector selected_token_idxes; + selected_token_idxes.reserve(num_sequences); + std::vector sample_idxes; + sample_idxes.reserve(num_sequences); + std::vector> generated_tokens; + generated_tokens.reserve(num_sequences); + + // Multi-threading optimization: Use parallel processing when sequence count + // exceeds threshold and thread pool is available + ThreadPool* threadpool = thread_pool_; + if (num_sequences >= THREADPOOL_THRESHOLD && threadpool != nullptr) { + // Thread-safe result containers + std::vector> thread_flatten_tokens(num_sequences); + std::vector thread_sampling_params( + num_sequences); + std::vector thread_selected_token_idxes(num_sequences); + std::vector thread_sample_idxes(num_sequences); + std::vector> thread_generated_tokens(num_sequences); + + // Calculate thread allocation + const size_t num_threads = + std::min(static_cast(num_sequences), static_cast(16)); + const size_t sequences_per_thread = + (num_sequences + num_threads - 1) / num_threads; + + std::vector> futures; + std::vector>> promises; + futures.reserve(num_threads); + promises.reserve(num_threads); + + // Parallel 
processing function + auto process_sequences_range = [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; + i < end_idx && i < static_cast(num_sequences); + ++i) { + const Sequence* sequence = nullptr; + // Get sequence from sequence_groups + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + if (seq_idx + group->sequences().size() > i) { + sequence = group->sequences()[i - seq_idx].get(); + break; + } + seq_idx += group->sequences().size(); + } + + if (!sequence) continue; + + const auto& token_ids = sequence->tokens(); + + // Build generated tokens + auto& cur_generated_tokens = thread_generated_tokens[i]; + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + + // Build flatten tokens + auto& cur_flatten_tokens = thread_flatten_tokens[i]; + cur_flatten_tokens.reserve(seq_len); + cur_flatten_tokens.insert(cur_flatten_tokens.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Set sampling parameters + thread_sampling_params[i] = sequence->sampling_param(); + thread_sample_idxes[i] = static_cast(i); + } + }; + + // Launch parallel tasks + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_idx = thread_idx * sequences_per_thread; + size_t end_idx = std::min(start_idx + sequences_per_thread, + static_cast(num_sequences)); + + if (start_idx >= static_cast(num_sequences)) break; + + auto promise = std::make_shared>(); + futures.push_back(promise->get_future()); + promises.push_back(promise); + + threadpool->schedule( + [process_sequences_range, start_idx, end_idx, promise]() mutable { + try { + process_sequences_range(start_idx, end_idx); + promise->set_value(); + } catch (...) 
{ + promise->set_exception(std::current_exception()); + } + }); + } + + // Wait for all tasks to complete + for (auto& future : futures) { + future.get(); + } + + // Merge results + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + for (int32_t i = 0; i < num_sequences; ++i) { + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + thread_flatten_tokens[i].begin(), + thread_flatten_tokens[i].end()); + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(thread_sampling_params[i]); + sample_idxes.push_back(thread_sample_idxes[i]); + generated_tokens.push_back(std::move(thread_generated_tokens[i])); + } + } else { + // Original single-thread processing logic + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + for (const auto& sequence : group->sequences()) { + const auto& token_ids = sequence->tokens(); + + // Optimize generated tokens construction + auto& cur_generated_tokens = generated_tokens.emplace_back(); + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + // Optimize token processing - batch operations + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Simplify sampling parameter processing + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(sequence->sampling_param()); + sample_idxes.push_back(seq_idx); + seq_idx++; + } + } + } + + return std::make_tuple(std::move(flatten_tokens_vec), + std::move(sampling_params), + std::move(selected_token_idxes), + std::move(sample_idxes), + std::move(generated_tokens)); + }; + + // ========== Comprehensive parallel execution of optimized data construction + // ========== Use thread pool to execute all independent data construction + // tasks in parallel + std::future&> encoder_future; + std::future, + std::vector, + std::vector, + std::vector, + std::vector>>> + decoder_future; + + // Declare variables to store results + const std::vector* encoder_tokens_ptr = nullptr; + std::vector flatten_tokens_vec; + std::vector sampling_params; + std::vector selected_token_idxes; + std::vector sample_idxes; + std::vector> generated_tokens; + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Use ThreadPool's schedule method to execute independent tasks in parallel + // build_decoder_data_optimized handles multi-threading internally, no external + // parallel calls + + // Task 1: build_encoder_tokens_optimized + std::promise*> encoder_promise; + auto encoder_future = encoder_promise.get_future(); + thread_pool_->schedule([&, promise = std::move(encoder_promise)]() mutable { + const auto& result = build_encoder_tokens_optimized(); + promise.set_value(&result); + }); + // Wait for encoder to complete + encoder_tokens_ptr = encoder_future.get(); + // Task 2: build_decoder_data_optimized executes directly, handles + // multi-threading internally + std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + sample_idxes, + generated_tokens) = build_decoder_data_optimized(); + } else { + // Single-thread execution (original logic) + encoder_tokens_ptr = &build_encoder_tokens_optimized(); + std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + 
sample_idxes, + generated_tokens) = build_decoder_data_optimized(); + } + + const auto& encoder_tokens = *encoder_tokens_ptr; + + // ========== High-performance ForwardInput construction ========== + ForwardInput forward_input; + auto& input_params = forward_input.input_params; + auto& onerec_params = input_params.mutable_onerec_params(); + auto& cache_data = perf_cache.cache_data; + + // Initialize key fields for asynchronous tasks + const int64_t bs = sequence_groups_.size(); + const int64_t group_width = + sequence_groups_.empty() ? 1 : sequence_groups_[0]->sequences().size(); + + std::vector> decoder_embedding_futures; + torch::Tensor result_embedding; + + // ========== Parallel tensor construction tasks ========== + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Only use parallelization for time-consuming tasks (token_ids and + // encoder_token_ids) + std::promise token_ids_promise; + std::promise encoder_token_ids_promise; + + auto token_ids_future = token_ids_promise.get_future(); + // auto encoder_token_ids_future = encoder_token_ids_promise.get_future(); + + // Task 1: Build token_ids tensor - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + thread_pool_->schedule([&flatten_tokens_vec, + promise = std::move(token_ids_promise)]() mutable { + try { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + auto tensor = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + promise.set_value(std::move(tensor)); + } catch (...) { + promise.set_exception(std::current_exception()); + } + }); + + // Task 2: Build encoder_token_ids tensor (if needed) - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + /* + thread_pool_->schedule( + [&encoder_tokens, + promise = std::move(encoder_token_ids_promise)]() mutable { + try { + torch::Tensor tensor; + if (!encoder_tokens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid + // clone operations + tensor = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + } + promise.set_value(std::move(tensor)); + } catch (...) 
{ + promise.set_exception(std::current_exception()); + } + }); + */ + if (!perf_cache.cache_data.decoder_context_embeddings.empty()) { + // Task 3: Synchronously process decoder_embedding, inner group dimension + // parallelization optimization + + // Optimization: Directly get shape information from first embedding to + // avoid torch::cat + auto first_embedding = + perf_cache.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + + // Create tensor on pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + + // Calculate total sequence length, pre-allocate context_len + seq_len + int64_t total_seq_len = context_len + seq_len; + + auto combined_embedding = + torch::empty({bs, group_width, total_seq_len, hidden_size}, options); + + // High-performance optimization: group dimension segmented + // parallelization + void* dst_data = combined_embedding.data_ptr(); + + // Get element size (supports float, bfloat16 and other types) + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t group_stride = total_seq_len * hidden_size * element_size; + const size_t batch_stride = + group_width * total_seq_len * hidden_size * element_size; + + // Parallelization strategy: segment by group dimension, consistent with + // thread calculations elsewhere + const size_t num_threads = + std::min(static_cast(group_width), static_cast(16)); + const size_t groups_per_thread = + (group_width + num_threads - 1) / num_threads; + + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_group = thread_idx * groups_per_thread; + size_t end_group = std::min(start_group + groups_per_thread, + static_cast(group_width)); + + if (start_group >= static_cast(group_width)) break; + + std::promise promise; + decoder_embedding_futures.push_back(promise.get_future()); + + thread_pool_->schedule( + [start_group, + end_group, + bs, + dst_data, + context_len, + hidden_size, + element_size, + batch_stride, + group_stride, + context_size, + embeddings = perf_cache.cache_data.decoder_context_embeddings, + dst_tensor = combined_embedding, + promise = std::move(promise)]() mutable { + // Copy context_embedding for specified group range of each batch + for (int64_t b = 0; b < bs; ++b) { + // Optimization: Access corresponding batch embedding directly + // through index + const void* batch_src = embeddings[b].data_ptr(); + auto* batch_dst = + static_cast(dst_data) + b * batch_stride; + + for (size_t g = start_group; g < end_group; ++g) { + std::memcpy( + batch_dst + g * group_stride, batch_src, context_size); + } + } + promise.set_value(); + }); + } + + result_embedding = combined_embedding; + } + + // Task 4: Build sequence length vector - changed to serial execution (very + // time-consuming, ~0.001785ms) + std::vector cu_seq_lens, q_cu_seq_lens; +#if defined(USE_NPU) + // use all prefill; + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + 
cu_seq_lens.push_back(cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + // Task 5: Build encoder_seq_lens_tensor - changed to serial execution (less + // time-consuming) + torch::Tensor encoder_seq_lens_tensor; + if (!cache_data.encoder_seq_lens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + + // Set basic parameters simultaneously (not dependent on asynchronous tasks) + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + forward_input.positions = perf_cache.fixed_positions_tensor; + + // Wait and collect results + forward_input.token_ids = token_ids_future.get(); + // auto encoder_token_ids = encoder_token_ids_future.get(); + + // seq_lens has been changed to serial execution, use the constructed + // variable directly + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + // encoder_seq_lens_tensor has been changed to serial execution, use the + // constructed variable directly + if (encoder_seq_lens_tensor.defined()) { + onerec_params.encoder_seq_lens_tensor = + std::move(encoder_seq_lens_tensor); + onerec_params.encoder_seq_lens = cache_data.encoder_seq_lens; + } + onerec_params.encoder_positions = perf_cache.fixed_encoder_positions_tensor; + } else { + // Single-threaded execution (original logic) + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + forward_input.token_ids = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(forward_input.token_ids.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + forward_input.positions = perf_cache.fixed_positions_tensor; + + if (!encoder_tokens.empty()) { + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + onerec_params.encoder_token_ids = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(onerec_params.encoder_token_ids.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + 
} + onerec_params.encoder_positions = perf_cache.fixed_encoder_positions_tensor; + // Pre-allocate and batch fill + std::vector cu_seq_lens, q_cu_seq_lens; +#if defined(USE_NPU) + // use all prefill; + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + cu_seq_lens.push_back(cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + if (!cache_data.encoder_seq_lens.empty()) { + // Set OneRecModelInputParams encoder data + onerec_params.encoder_seq_lens = cache_data.encoder_seq_lens; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + onerec_params.encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(onerec_params.encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + } + + // ========== Parallel processing of independent code blocks ========== + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Define promise/future for parallel tasks + std::promise block_tables_promise; + auto block_tables_future = block_tables_promise.get_future(); + + // Task 1: Empty block tables processing - use thread pool (relatively + // time-consuming) + thread_pool_->schedule( + [&input_params, num_sequences, &block_tables_promise]() mutable { + try { + std::vector> empty_block_tables(num_sequences); + util::pad_2d_vector(empty_block_tables, 0); + // Optimization: Use create_2d_tensor_optimized, has special + // optimization for all-zero matrices + input_params.block_tables = + create_2d_tensor(empty_block_tables, torch::kInt); + + std::vector paged_kv_indptr(num_sequences + 1, 0); + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.new_cache_slots = + torch::empty({static_cast(paged_kv_indptr.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.new_cache_slots.data_ptr(), + 
                        paged_kv_indptr.data(),
+                        paged_kv_indptr.size() * sizeof(int));
+
+            block_tables_promise.set_value();
+          } catch (...) {
+            block_tables_promise.set_exception(std::current_exception());
+          }
+        });
+
+    // Optimization: merge the remaining small tasks into sequential execution
+    // to reduce thread-switching overhead.
+    // Cross-attention parameter construction - use placeholders.
+    onerec_params.cross_attn_kv_cu_seq_lens = torch::zeros({1}, torch::kInt);
+    onerec_params.cross_attn_kv_cu_seq_lens_vec = {0};
+    onerec_params.cross_attn_block_tables = torch::zeros({1, 1}, torch::kInt);
+
+    // Sampling parameter processing
+    if (!selected_token_idxes.empty()) {
+      forward_input.sampling_params.init(sampling_params,
+                                         selected_token_idxes,
+                                         sample_idxes,
+                                         std::vector<std::vector<int64_t>>{},
+                                         std::vector<std::vector<int32_t>>{},
+                                         std::vector<int32_t>{});
+    }
+
+    // First prefill processing - use placeholder
+    if (is_first_prefill) {
+      // Use a placeholder instead of the full cross_attn_new_cache_slots
+      // construction.
+      onerec_params.cross_attn_new_cache_slots =
+          torch::zeros({1}, torch::kInt);
+    }
+
+    // Wait for the parallel task to complete (only block_tables uses the
+    // thread pool).
+    block_tables_future.wait();
+  } else {
+    // ========== Non-parallel case: sequential processing ==========
+    // Empty block tables processing.
+    std::vector<std::vector<int32_t>> empty_block_tables(num_sequences);
+    util::pad_2d_vector(empty_block_tables, 0);
+    // Optimization: create_2d_tensor has a fast path for all-zero matrices.
+    input_params.block_tables =
+        create_2d_tensor(empty_block_tables, torch::kInt);
+
+    std::vector<int32_t> paged_kv_indptr(num_sequences + 1, 0);
+    // Optimization: use torch::empty + std::memcpy instead of
+    // torch::from_blob().clone().
+    input_params.new_cache_slots =
+        torch::empty({static_cast<int64_t>(paged_kv_indptr.size())},
+                     torch::TensorOptions()
+                         .dtype(torch::kInt)
+                         .device(torch::kCPU)
+                         .pinned_memory(true));
+    std::memcpy(input_params.new_cache_slots.data_ptr(),
+                paged_kv_indptr.data(),
+                paged_kv_indptr.size() * sizeof(int));
+
+    // ===== Cross-attention parameter construction (using placeholders) =====
+    // Use placeholder tensors instead of actual data.
+    onerec_params.cross_attn_kv_cu_seq_lens = torch::zeros({1}, torch::kInt);
+    onerec_params.cross_attn_kv_cu_seq_lens_vec = {0};
+
+    // Use a placeholder tensor instead of actual data.
+    onerec_params.cross_attn_block_tables = torch::zeros({1, 1}, torch::kInt);
+
+    // ========== Sampling parameter processing ==========
+    if (!selected_token_idxes.empty()) {
+      forward_input.sampling_params.init(sampling_params,
+                                         selected_token_idxes,
+                                         sample_idxes,
+                                         std::vector<std::vector<int64_t>>{},
+                                         std::vector<std::vector<int32_t>>{},
+                                         std::vector<int32_t>{});
+    }
+
+    // ========== First prefill processing (using a placeholder) ==========
+    if (is_first_prefill) {
+      // Use a placeholder tensor instead of actual data.
+      onerec_params.cross_attn_new_cache_slots =
+          torch::zeros({1}, torch::kInt);
+    }
+  }
+
+  // ========== Common parameter settings ==========
+  // Batch-set the remaining parameters.
+  input_params.embedding_ids.assign(num_sequences, 0);
+
+#if defined(USE_NPU)
+  auto prefill_indices = util::find_ones_indices(input_params.q_seq_lens_vec);
+  input_params.decode_seq_range =
+      std::make_pair(0, static_cast<int32_t>(flatten_tokens_vec.size()));
+#else
+  input_params.decode_seq_range = {
+      0, static_cast<int32_t>(flatten_tokens_vec.size())};
+#endif
+
+  // OneRec model parameters
+  onerec_params.rec_stage = OneRecModelInputParams::RecStage::PREFILL;
+  onerec_params.is_hybrid_mode = false;
+  onerec_params.has_encoder_output = true;
+  onerec_params.is_first_prefill =
is_first_prefill; + onerec_params.bs = bs; + onerec_params.group_width = group_width; + onerec_params.seq_len = seq_len; + onerec_params.encoder_max_seq_len = + cache_data.encoder_seq_lens.empty() + ? 0 + : *std::max_element(cache_data.encoder_seq_lens.begin(), + cache_data.encoder_seq_lens.end()); + + onerec_params.generated_tokens = std::move(generated_tokens); + + // Process sparse_embedding: Efficiently concatenate from cache_data + if (!perf_cache.cache_data.encoder_sparse_embeddings.empty()) { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + onerec_params.encoder_sparse_embedding = + torch::cat(perf_cache.cache_data.encoder_sparse_embeddings, /*dim=*/0); + } + + if (!perf_cache.cache_data.decoder_context_embeddings.empty()) { + // Get group_width + const int64_t group_width_val = onerec_params.group_width; + if (group_width_val == 1 && seq_len == 0) { + // Optimization: When bs==1, directly use the first embedding to avoid + // unnecessary torch::cat + if (bs == 1) { + onerec_params.decoder_context_embedding = + perf_cache.cache_data.decoder_context_embeddings[0]; + } else { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + auto original_context_embedding = torch::cat( + perf_cache.cache_data.decoder_context_embeddings, /*dim=*/0); + onerec_params.decoder_context_embedding = original_context_embedding; + } + } else if (group_width_val == 1 && seq_len > 0) { + // Handle the scenario where group_width==1 and seq_len>0 + // Get information from the first embedding + const auto& first_embedding = + perf_cache.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + int64_t total_seq_len = context_len + seq_len; + + // Allocate a tensor of shape {bs, 1, total_seq_len, hidden_size}, + // optimized with pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + auto combined_embedding = + torch::empty({bs, 1, total_seq_len, hidden_size}, options); + + // Single-threaded copy of context_len portion of data + void* dst_data = combined_embedding.data_ptr(); + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t batch_stride = total_seq_len * hidden_size * element_size; + + // Copy context_embedding for each batch + for (int64_t b = 0; b < bs; ++b) { + const void* batch_src = + perf_cache.cache_data.decoder_context_embeddings[b].data_ptr(); + auto* batch_dst = static_cast(dst_data) + b * batch_stride; + std::memcpy(batch_dst, batch_src, context_size); + } + onerec_params.decoder_context_embedding = combined_embedding; + } else { + for (auto& future : decoder_embedding_futures) { + future.get(); + } + onerec_params.decoder_context_embedding = std::move(result_embedding); + } + } + + return forward_input; +} + +} // namespace xllm diff --git a/xllm/core/framework/batch/onerec_batch_input_builder.h b/xllm/core/framework/batch/onerec_batch_input_builder.h new file mode 100644 index 000000000..ac5179d88 --- /dev/null +++ b/xllm/core/framework/batch/onerec_batch_input_builder.h @@ -0,0 +1,119 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include +#include + +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/request/mm_data.h" +#include "framework/request/sequence.h" +#include "framework/request/sequences_group.h" +#include "rec_batch_input_builder.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" + +namespace xllm { + +class OneRecBatchInputBuilder : public RecBatchInputBuilder { + public: + explicit OneRecBatchInputBuilder( + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + const uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool = nullptr); + + public: + ForwardInput build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) override; + + private: + const std::vector& sequence_groups_; + const std::vector& allowed_max_tokens_; + const std::vector& input_embeddings_vec_; + const std::vector& mm_data_vec_; + std::vector* swap_block_transfer_infos_ = nullptr; + const uint64_t batch_id_; + const ModelArgs* args_ = nullptr; + ThreadPool* thread_pool_ = nullptr; + + // High performance cache system + struct HighPerformanceCache { + // Memory pool - avoid frequent allocation/deallocation + struct MemoryPool { + std::vector> int32_pools; + size_t pool_index = 0; + + std::vector& get_int32_vector(size_t reserve_size = 0) { + if (pool_index >= int32_pools.size()) { + int32_pools.emplace_back(); + } + auto& vec = int32_pools[pool_index++]; + vec.clear(); + if (reserve_size > 0) vec.reserve(reserve_size); + return vec; + } + + void reset() { pool_index = 0; } + }; + + // Cache data structure + struct CacheData { + std::vector encoder_tokens; + std::vector encoder_seq_lens; + std::vector encoder_sparse_embeddings; + std::vector decoder_context_embeddings; + }; + + // Pre-created constant tensors - lazy initialized to avoid static + // initialization order issues + torch::Tensor fixed_positions_tensor; + torch::Tensor fixed_encoder_positions_tensor; + torch::Tensor empty_tensor; + bool tensors_initialized = false; + + MemoryPool memory_pool; + CacheData cache_data; + + // Default constructor - does NOT create tensors to avoid static + // initialization order fiasco + HighPerformanceCache() = default; + + // Lazy initialization of tensors - must be called before first use + void ensure_tensors_initialized() { + if (!tensors_initialized) { + fixed_positions_tensor = torch::tensor({0}, torch::kInt); + fixed_encoder_positions_tensor = torch::tensor({0}, torch::kInt); + empty_tensor = torch::tensor(std::vector{}, torch::kInt); + tensors_initialized = true; + } + } + }; + + // Use function-local static to ensure proper initialization order + // (Meyers' Singleton pattern) + static HighPerformanceCache& get_perf_cache(); +}; + +} // namespace xllm diff --git a/xllm/core/framework/batch/rec_batch_input_builder.cpp 
b/xllm/core/framework/batch/rec_batch_input_builder.cpp new file mode 100644 index 000000000..b86449e2b --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.cpp @@ -0,0 +1,58 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_batch_input_builder.h" + +#include + +#include +#include + +#include "onerec_batch_input_builder.h" + +namespace xllm { + +std::unique_ptr RecBatchInputBuilder::create( + RecType rec_type, + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool) { + switch (rec_type) { + case RecType::kOneRec: + return std::make_unique( + sequence_groups, + allowed_max_tokens, + input_embeddings_vec, + mm_data_vec, + swap_block_transfer_infos, + batch_id, + args, + thread_pool); + case RecType::kNone: + case RecType::kLlmRec: + break; + } + + LOG(FATAL) << "Unsupported RecType for RecBatchInputBuilder: " + << static_cast(rec_type); + return nullptr; +} + +} // namespace xllm diff --git a/xllm/core/framework/batch/rec_batch_input_builder.h b/xllm/core/framework/batch/rec_batch_input_builder.h new file mode 100644 index 000000000..831bba598 --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.h @@ -0,0 +1,53 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include + +#include +#include +#include + +#include "framework/model/model_args.h" +#include "framework/request/mm_data.h" +#include "framework/request/rec_type.h" +#include "framework/request/sequences_group.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" + +namespace xllm { + +class RecBatchInputBuilder { + public: + virtual ~RecBatchInputBuilder() = default; + + virtual ForwardInput build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) = 0; + + static std::unique_ptr create( + RecType rec_type, + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool = nullptr); +}; + +} // namespace xllm diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 08aa30333..4cd0b2c5c 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include +#include + #if defined(USE_NPU) #include "platform/npu/npu_layer_synchronizer.h" #endif @@ -29,6 +32,117 @@ limitations under the License. namespace xllm { +struct OneRecModelInputParams { + enum class RecStage { + PREFILL, + DECODE, + }; + + RecStage rec_stage = RecStage::PREFILL; + bool is_hybrid_mode = false; + bool is_encoder_forward = false; + bool has_encoder_output = false; + std::vector encoder_seq_lens; + torch::Tensor encoder_seq_lens_tensor; + int32_t encoder_max_seq_len = 0; + + bool is_first_prefill = true; + int32_t bs = 0; + int32_t group_width = 0; + int32_t seq_len = 0; + std::vector> generated_tokens; + torch::Tensor encoder_sparse_embedding; + torch::Tensor decoder_context_embedding; + + torch::Tensor cross_attn_kv_cu_seq_lens; + torch::Tensor cross_attn_new_cache_slots; + torch::Tensor cross_attn_block_tables; + std::vector cross_attn_kv_cu_seq_lens_vec; + + torch::Tensor encoder_token_ids; + torch::Tensor encoder_positions; + + OneRecModelInputParams to(const c10::Device& device) const { + OneRecModelInputParams result = *this; + + if (encoder_seq_lens_tensor.defined()) { + result.encoder_seq_lens_tensor = encoder_seq_lens_tensor.to(device); + } + if (encoder_sparse_embedding.defined()) { + result.encoder_sparse_embedding = encoder_sparse_embedding.to(device); + } + if (decoder_context_embedding.defined()) { + result.decoder_context_embedding = decoder_context_embedding.to(device); + } + if (cross_attn_kv_cu_seq_lens.defined()) { + result.cross_attn_kv_cu_seq_lens = cross_attn_kv_cu_seq_lens.to(device); + } + if (cross_attn_new_cache_slots.defined()) { + result.cross_attn_new_cache_slots = cross_attn_new_cache_slots.to(device); + } + if (cross_attn_block_tables.defined()) { + result.cross_attn_block_tables = cross_attn_block_tables.to(device); + } + if (encoder_token_ids.defined()) { + result.encoder_token_ids = encoder_token_ids.to(device); + } + if (encoder_positions.defined()) { + result.encoder_positions = encoder_positions.to(device); + } + + return result; + } + + void print() const { + LOG(INFO) << "OneRecModelInputParams:" + << " rec_stage: " + << (rec_stage == RecStage::PREFILL ? 
"PREFILL" : "DECODE") + << " is_hybrid_mode: " << is_hybrid_mode + << " is_encoder_forward: " << is_encoder_forward + << " has_encoder_output: " << has_encoder_output + << " encoder_max_seq_len: " << encoder_max_seq_len + << " is_first_prefill: " << is_first_prefill << " bs: " << bs + << " group_width: " << group_width << " seq_len: " << seq_len + << " encoder_seq_lens size: " << encoder_seq_lens.size() + << " cross_attn_kv_cu_seq_lens_vec size: " + << cross_attn_kv_cu_seq_lens_vec.size() + << " generated_tokens size: " << generated_tokens.size(); + if (encoder_seq_lens_tensor.defined()) { + LOG(INFO) << " encoder_seq_lens_tensor shape: " + << encoder_seq_lens_tensor.sizes(); + } + if (encoder_sparse_embedding.defined()) { + LOG(INFO) << " encoder_sparse_embedding shape: " + << encoder_sparse_embedding.sizes(); + } + if (decoder_context_embedding.defined()) { + LOG(INFO) << " decoder_context_embedding shape: " + << decoder_context_embedding.sizes(); + } + if (cross_attn_kv_cu_seq_lens.defined()) { + LOG(INFO) << " cross_attn_kv_cu_seq_lens shape: " + << cross_attn_kv_cu_seq_lens.sizes(); + } + if (cross_attn_new_cache_slots.defined()) { + LOG(INFO) << " cross_attn_new_cache_slots shape: " + << cross_attn_new_cache_slots.sizes(); + } + if (cross_attn_block_tables.defined()) { + LOG(INFO) << " cross_attn_block_tables shape: " + << cross_attn_block_tables.sizes(); + } + if (encoder_token_ids.defined()) { + LOG(INFO) << " encoder_token_ids shape: " << encoder_token_ids.sizes(); + } + if (encoder_positions.defined()) { + LOG(INFO) << " encoder_positions shape: " << encoder_positions.sizes(); + } + } +}; + +using RecModelInputParams = + std::variant; + enum class TransferType : uint8_t { G2H = 0, // global memory(KVCache store) to host memory(DRAM) H2D = 1, // host memory(DRAM) to device memory(HBM) @@ -159,6 +273,10 @@ struct ModelInputParams { params.batch_id = batch_id; + if (const auto* onerec = onerec_params()) { + params.rec_params = onerec->to(device); + } + return params; } @@ -179,6 +297,11 @@ struct ModelInputParams { print_tensor(block_tables, "ModelInputParams: block_tables", 4); LOG(INFO) << "ModelInputParams: dp_global_token_nums is " << dp_global_token_nums; + + if (const auto* onerec = onerec_params()) { + LOG(INFO) << "ModelInputParams: has rec_params"; + onerec->print(); + } } int32_t get_q_seq_len(int32_t seq_idx) const { @@ -294,6 +417,21 @@ struct ModelInputParams { uint64_t batch_id; + RecModelInputParams rec_params; + + const OneRecModelInputParams* onerec_params() const { + return std::get_if(&rec_params); + } + + bool has_onerec_params() const { return onerec_params() != nullptr; } + + OneRecModelInputParams& mutable_onerec_params() { + if (!has_onerec_params()) { + rec_params.emplace(); + } + return std::get(rec_params); + } + struct GraphBuffer { torch::Tensor attn_mask; torch::Tensor tiling_data; diff --git a/xllm/core/framework/request/CMakeLists.txt b/xllm/core/framework/request/CMakeLists.txt index e3f4ba8c7..16f9f9872 100644 --- a/xllm/core/framework/request/CMakeLists.txt +++ b/xllm/core/framework/request/CMakeLists.txt @@ -25,6 +25,7 @@ cc_library( sequence_kv_state.h sequences_group.h request_state.h + rec_type.h stopping_checker.h priority_comparator.h SRCS diff --git a/xllm/core/framework/request/rec_type.h b/xllm/core/framework/request/rec_type.h new file mode 100644 index 000000000..55a4df78f --- /dev/null +++ b/xllm/core/framework/request/rec_type.h @@ -0,0 +1,28 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +namespace xllm { + +enum class RecType : uint8_t { + kNone = 0, + kOneRec = 1, + kLlmRec = 2, +}; + +} // namespace xllm diff --git a/xllm/core/framework/request/request.cpp b/xllm/core/framework/request/request.cpp index 4a6ff14f4..0c5fa35aa 100644 --- a/xllm/core/framework/request/request.cpp +++ b/xllm/core/framework/request/request.cpp @@ -56,6 +56,8 @@ void Request::create_sequences_group() { sequence_params.best_of = state_.best_of; sequence_params.streaming = state_.stream; sequence_params.enable_schedule_overlap = state_.enable_schedule_overlap; + sequence_params.rec_type = state_.rec_type; + sequence_params.bos_token_id = state_.bos_token_id; sequence_params.sampling_param = &(state_.sampling_param); sequence_params.stopping_checker = &(state_.stopping_checker); sequences_group_ = std::make_unique(state_.prompt, diff --git a/xllm/core/framework/request/request_state.h b/xllm/core/framework/request/request_state.h index 5ff043220..5558c9734 100644 --- a/xllm/core/framework/request/request_state.h +++ b/xllm/core/framework/request/request_state.h @@ -25,6 +25,7 @@ limitations under the License. #include "core/framework/sampling/sampling_params.h" #include "mm_data.h" +#include "rec_type.h" #include "request_output.h" #include "stopping_checker.h" @@ -137,6 +138,10 @@ struct RequestState final { bool enable_schedule_overlap = false; + RecType rec_type = RecType::kNone; + + int32_t bos_token_id = 0; + // The thread id of the thread pool in the response handler to ensure that // stream responses for the same request are executed sequentially during // multi-threaded stream processing. diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp index 346705d5b..ea346fd22 100644 --- a/xllm/core/framework/request/sequence.cpp +++ b/xllm/core/framework/request/sequence.cpp @@ -34,6 +34,85 @@ limitations under the License. 
namespace xllm { +namespace { +constexpr size_t kDecoderBosTokenCount = 1; +constexpr size_t kDecoderMaxTokenCount = 4; +} // namespace + +const std::string Sequence::ENCODER_SPARSE_EMBEDDING_NAME = "sparse_embedding"; +const std::string Sequence::DECODER_CONTEXT_EMBEDDING_NAME = + "decoder_context_embedding"; + +void Sequence::init_onerec_sequence( + const std::vector& prompt_token_ids, + torch::Tensor input_embedding) { + auto& onerec_state = onerec_state_.emplace(); + if (!prompt_token_ids.empty()) { + onerec_state.encoder_tokens.assign(prompt_token_ids.begin(), + prompt_token_ids.end()); + onerec_state.num_encoder_tokens = prompt_token_ids.size(); + } else { + auto encoder_sparse_embedding = + mm_data_.get(ENCODER_SPARSE_EMBEDDING_NAME); + CHECK(encoder_sparse_embedding.has_value()) + << "encoder sparse embedding not found in mm_data"; + onerec_state.num_encoder_tokens = encoder_sparse_embedding.value().size(0); + } + + auto decoder_context_embedding = + mm_data_.get(DECODER_CONTEXT_EMBEDDING_NAME); + + size_t capacity = kDecoderMaxTokenCount; + if (decoder_context_embedding.has_value()) { + num_prompt_tokens_ = 0; + onerec_state.num_decoder_embeddings = + decoder_context_embedding.value().size(0); + capacity = onerec_state.num_decoder_embeddings + capacity - + kDecoderBosTokenCount; + } else { + num_prompt_tokens_ = kDecoderBosTokenCount; + } + + tokens_.resize(capacity); + for (size_t i = 0; i < num_prompt_tokens_; ++i) { + tokens_[num_tokens_++] = sequence_params_.bos_token_id; + token_to_count_map_[sequence_params_.bos_token_id]++; + } + volatile_num_prompt_tokens_ = num_prompt_tokens_; + input_embedding_ = std::move(input_embedding); + cur_generated_token_idx_ = num_prompt_tokens_; + logprob_state_ = std::make_unique(num_prompt_tokens_, capacity); +} + +SequenceOutput Sequence::build_onerec_output(const Slice& ids, + size_t size, + SequenceOutput output) const { + output.token_ids = ids.slice(num_prompt_tokens_, size); + return output; +} + +SequenceOutput Sequence::build_onerec_streaming_output( + const Slice& ids, + size_t size) const { + SequenceOutput output; + output.index = index_; + output.token_ids = ids.slice(num_prompt_tokens_, size); + return output; +} + +SequenceOutput Sequence::generate_onerec_output(const Slice& ids, + size_t size) const { + SequenceOutput output; + output.index = index_; + if (output_embedding_.defined()) { + output.embedding = output_embedding_; + } + if (finish_reason_ != FinishReason::NONE) { + output.finish_reason = finish_reason_.to_string(); + } + return build_onerec_output(ids, size, std::move(output)); +} + Sequence::Sequence(size_t index, const std::vector& prompt_token_ids, torch::Tensor input_embedding, @@ -45,7 +124,13 @@ Sequence::Sequence(size_t index, latest_generate_time_(absl::Now()), sequence_params_(seq_params), decoder_(std::move(decoder)), + rec_type_(seq_params.rec_type), termination_flag_(std::make_shared>(INT32_MAX)) { + if (is_onerec_model()) { + init_onerec_sequence(prompt_token_ids, std::move(input_embedding)); + return; + } + CHECK(!prompt_token_ids.empty()) << "empty prompt token ids"; auto capacity = sequence_params_.seq_capacity; CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small"; @@ -85,6 +170,8 @@ Sequence::Sequence(const Sequence& other) num_tokens_(other.num_tokens_), token_to_count_map_(other.token_to_count_map_), num_prompt_tokens_(other.num_prompt_tokens_), + onerec_state_(other.onerec_state_), + rec_type_(other.rec_type_), volatile_num_prompt_tokens_(other.volatile_num_prompt_tokens_), 
embedding_id_(other.embedding_id_), finished_(other.finished_), @@ -103,8 +190,10 @@ void Sequence::append_token(const Token& token) { CHECK_LT(num_tokens_, tokens_.size()) << "exceed the token capacity of the sequence"; CHECK(!finished_) << "cannot append token to a finished sequence"; - CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_chunked_prefill_stage()) - << "cannot append token to a prefill sequence"; + if (!is_onerec_model()) { + CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_chunked_prefill_stage()) + << "cannot append token to a prefill sequence"; + } if (!sequence_params_.enable_schedule_overlap) { // check if the token is the first token after the prompt @@ -250,6 +339,10 @@ std::optional Sequence::generate_streaming_output( AUTO_COUNTER(detokenization_latency_seconds_stream); const auto ids = Slice(tokens_, size); + if (is_onerec_model()) { + return build_onerec_streaming_output(ids, size); + } + // record the start index of token ids const size_t start = decoder_.output_offset(); auto delta = decoder_.decode(ids, tokenizer); @@ -333,6 +426,19 @@ SequenceOutput Sequence::generate_output(const Tokenizer& tokenizer) { } } + if (is_onerec_model()) { + return generate_onerec_output(ids, size); + } + + SequenceOutput output; + output.index = index_; + if (output_embedding_.defined()) { + output.embedding = output_embedding_; + } + if (finish_reason_ != FinishReason::NONE) { + output.finish_reason = finish_reason_.to_string(); + } + // record the start index of token ids const size_t start = decoder_.output_offset(); @@ -349,16 +455,7 @@ SequenceOutput Sequence::generate_output(const Tokenizer& tokenizer) { ss << decoder_.decode(ids.slice(0, end), tokenizer); } - SequenceOutput output; - output.index = index_; output.text = ss.str(); - if (output_embedding_.defined()) { - output.embedding = output_embedding_; - } - - if (finish_reason_ != FinishReason::NONE) { - output.finish_reason = finish_reason_.to_string(); - } const size_t end = decoder_.output_offset(); output.token_ids = ids.slice(start, end); @@ -402,6 +499,10 @@ bool Sequence::finished() const { return finished_; } + if (is_onerec_model() && num_tokens_ == num_prompt_tokens_) { + return false; + } + // Embedding sequence never be finished until it updates its embeddings if (finish_status_invalidated_ && sequence_params_.sampling_param->is_embeddings) { diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h index 520929d04..dd5609ed9 100644 --- a/xllm/core/framework/request/sequence.h +++ b/xllm/core/framework/request/sequence.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include +#include #include #include "core/common/types.h" @@ -30,6 +32,7 @@ limitations under the License. #include "framework/block/block.h" #include "incremental_decoder.h" #include "mm_data.h" +#include "rec_type.h" #include "request_output.h" #include "sequence_kv_state.h" #include "sequence_logprob_state.h" @@ -73,6 +76,10 @@ struct SequenceParams { // enable_schedule_overlap or not. default = false. 
bool enable_schedule_overlap = false; + RecType rec_type = RecType::kNone; + + int32_t bos_token_id = 0; + // sampling params // reference from request RequestSamplingParam* sampling_param; // not owned @@ -268,11 +275,57 @@ class Sequence final { // get sequence id int32_t seq_id() const { return seq_id_; } + const std::vector& encoder_tokens() const { + static const std::vector kEmpty; + if (!onerec_state_.has_value()) { + return kEmpty; + } + return onerec_state_->encoder_tokens; + } + + size_t encoder_seq_len() const { + return onerec_state_.has_value() ? onerec_state_->num_encoder_tokens + : 0; + } + + size_t num_decoder_embeddings() const { + return onerec_state_.has_value() ? onerec_state_->num_decoder_embeddings + : 0; + } + + RecType rec_type() const { return rec_type_; } + + bool is_rec_request() const { return rec_type_ != RecType::kNone; } + + bool is_onerec_model() const { return rec_type_ == RecType::kOneRec; } + + static const std::string ENCODER_SPARSE_EMBEDDING_NAME; + static const std::string DECODER_CONTEXT_EMBEDDING_NAME; + void set_cancel() { cancelled_.store(true, std::memory_order_relaxed); } bool cancelled() const { return cancelled_.load(std::memory_order_relaxed); } private: + void init_onerec_sequence(const std::vector& prompt_token_ids, + torch::Tensor input_embedding); + + SequenceOutput build_onerec_output(const Slice& ids, + size_t size, + SequenceOutput output) const; + + SequenceOutput build_onerec_streaming_output(const Slice& ids, + size_t size) const; + + SequenceOutput generate_onerec_output(const Slice& ids, + size_t size) const; + + struct OneRecState { + size_t num_encoder_tokens = 0; + size_t num_decoder_embeddings = 0; + std::vector encoder_tokens; + }; + // the index of the sequence in the request size_t index_ = 0; @@ -319,6 +372,10 @@ class Sequence final { // the length of the prompt tokens size_t num_prompt_tokens_ = 0; + std::optional onerec_state_; + + RecType rec_type_ = RecType::kNone; + // NOTE: MUST FIXME Later // record all tokens num in last turn when the request is // interrupted due to the lack of kv cache capacity. diff --git a/xllm/core/scheduler/fixed_steps_scheduler.cpp b/xllm/core/scheduler/fixed_steps_scheduler.cpp index 0d0d51411..bbbb0131b 100644 --- a/xllm/core/scheduler/fixed_steps_scheduler.cpp +++ b/xllm/core/scheduler/fixed_steps_scheduler.cpp @@ -33,7 +33,6 @@ limitations under the License. #include "framework/request/sequence.h" namespace xllm { - FixedStepsScheduler::FixedStepsScheduler(Engine* engine, const Options& options) : ContinuousScheduler(engine, options) {} @@ -217,21 +216,12 @@ std::vector FixedStepsScheduler::prepare_batch() { response_processor_->process_completed_requests(finished_requests); } - // update the batch - // TODO. add following when next pr (use create_rec_batches). - // auto batches = BatchFactory::get_instance(options_.dp_size()) - // ->create_rec_batches( - // running_requests_, - // running_sequences_, - // running_sequences_budgets_, - // kv_cache_manager_->get_swap_block_transfer_infos()); - // TODO. delete this when next pr. 
-  auto batches =
-      BatchFactory::get_instance(options_.dp_size())
-          ->create_batches(running_requests_,
-                           running_sequences_,
-                           running_sequences_budgets_,
-                           kv_cache_manager_->get_swap_block_transfer_infos());
+  auto batches = BatchFactory::get_instance(options_.dp_size())
+                     ->create_rec_batches(
+                         running_requests_,
+                         running_sequences_,
+                         running_sequences_budgets_,
+                         kv_cache_manager_->get_swap_block_transfer_infos());
 
   // update metrics before returning
   if (!batches[0].empty()) {