diff --git a/xllm/api_service/rec_completion_service_impl.cpp b/xllm/api_service/rec_completion_service_impl.cpp
index 9ffc4261b..f09ed0367 100644
--- a/xllm/api_service/rec_completion_service_impl.cpp
+++ b/xllm/api_service/rec_completion_service_impl.cpp
@@ -28,9 +28,7 @@ limitations under the License.
 #include "completion.pb.h"
 #include "core/distributed_runtime/llm_master.h"
 #include "core/distributed_runtime/rec_master.h"
-#include "core/framework/request/mm_data.h"
 #include "core/framework/request/request_output.h"
-#include "core/util/utils.h"
 
 #define likely(x) __builtin_expect(!!(x), 1)
 #define unlikely(x) __builtin_expect(!!(x), 0)
@@ -167,18 +165,15 @@ void RecCompletionServiceImpl::process_async_impl(
   }
 
   const auto& rpc_request_ref = call->request();
-  std::optional<MMData> mm_data = std::nullopt;
+  std::optional> input_tensors =
+      std::nullopt;
   if (rpc_request_ref.input_tensors_size()) {
-    // HISTOGRAM_OBSERVE(rec_input_first_dim,
-    //                   rpc_request_ref.input_tensors(0).shape(0));
-
-    MMDict mm_dict;
+    std::vector tensors;
+    tensors.reserve(rpc_request_ref.input_tensors_size());
     for (int i = 0; i < rpc_request_ref.input_tensors_size(); ++i) {
-      const auto& tensor = rpc_request_ref.input_tensors(i);
-      mm_dict[tensor.name()] =
-          xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16);
+      tensors.push_back(rpc_request_ref.input_tensors(i));
     }
-    mm_data = std::move(MMData(MMType::EMBEDDING, mm_dict));
+    input_tensors = std::move(tensors);
   }
 
   // schedule the request
@@ -187,7 +182,7 @@ void RecCompletionServiceImpl::process_async_impl(
   master_->handle_request(
       std::move(rpc_request_ref.prompt()),
       std::move(prompt_tokens),
-      std::move(mm_data),
+      std::move(input_tensors),
       std::move(request_params),
       [call,
        model,
@@ -219,4 +214,4 @@ void RecCompletionServiceImpl::process_async_impl(
   });
 }
 
-}  // namespace xllm
\ No newline at end of file
+}  // namespace xllm
diff --git a/xllm/core/common/CMakeLists.txt b/xllm/core/common/CMakeLists.txt
index f8d54c2ec..9a047a45d 100644
--- a/xllm/core/common/CMakeLists.txt
+++ b/xllm/core/common/CMakeLists.txt
@@ -14,6 +14,7 @@ cc_library(
     $<$:mspti_helper.h>
     options.h
     rate_limiter.h
+    rec_model_utils.h
     types.h
     device_monitor.h
     version_singleton.h
@@ -66,4 +67,3 @@ cc_test(
 
 target_link_libraries(common PRIVATE OpenSSL::SSL OpenSSL::Crypto protobuf::libprotobuf)
 add_dependencies(common brpc-static)
-
diff --git a/xllm/core/common/rec_model_utils.h b/xllm/core/common/rec_model_utils.h
new file mode 100644
index 000000000..4aa5a06d4
--- /dev/null
+++ b/xllm/core/common/rec_model_utils.h
@@ -0,0 +1,48 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <cstdint>
+#include <string_view>
+
+namespace xllm {
+
+enum class RecModelKind : int8_t {
+  kNone = 0,
+  kOneRec = 1,
+  kLlmRec = 2,
+};
+
+inline constexpr bool is_onerec_model_type(std::string_view model_type) {
+  return model_type == "onerec";
+}
+
+inline constexpr bool is_llmrec_model_type(std::string_view model_type) {
+  return model_type == "qwen2" || model_type == "qwen3";
+}
+
+inline constexpr RecModelKind get_rec_model_kind(std::string_view model_type) {
+  if (is_onerec_model_type(model_type)) {
+    return RecModelKind::kOneRec;
+  }
+  if (is_llmrec_model_type(model_type)) {
+    return RecModelKind::kLlmRec;
+  }
+  return RecModelKind::kNone;
+}
+
+}  // namespace xllm
+
diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h
index 4f0dbb14b..8712e7a82 100644
--- a/xllm/core/common/types.h
+++ b/xllm/core/common/types.h
@@ -292,4 +292,9 @@ struct EplbInfo {
 
 inline constexpr int REC_TOKEN_SIZE = 3;
 using RecTokenTriple = std::array;
+
+inline constexpr const char* LLM_REC_INPUT_TOKENS = "llm_rec_input_tokens";
+inline constexpr const char* LLM_REC_INPUT_INDICES = "llm_rec_input_indices";
+inline constexpr const char* LLM_REC_INPUT_EMBEDDING =
+    "llm_rec_input_embedding";
 }  // namespace xllm
diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp
index 9fe82e593..0bf1b1844 100644
--- a/xllm/core/distributed_runtime/master.cpp
+++ b/xllm/core/distributed_runtime/master.cpp
@@ -22,6 +22,7 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -30,6 +31,7 @@ limitations under the License.
 #include "common/types.h"
 #include "dit_master.h"
 #include "framework/model/model_args.h"
+#include "framework/model_loader.h"
 #include "framework/request/request.h"
 #include "llm_engine.h"
 #include "llm_master.h"
@@ -237,8 +239,13 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
     options_.enable_schedule_overlap(false);
     LOG(WARNING) << "Force to disable schedule overlap for REC model, not "
                     "supported yet.";
+
+    auto model_loader = ModelLoader::create(options_.model_path());
+    const std::string rec_model_type = model_loader->model_args().model_type();
+
     runtime::Options eng_options;
     eng_options.model_path(options_.model_path())
+        .rec_model_type(std::make_optional(rec_model_type))
         .devices(devices)
         .backend(options_.backend())
         .block_size(options_.block_size())
diff --git a/xllm/core/distributed_runtime/rec_master.cpp b/xllm/core/distributed_runtime/rec_master.cpp
index 0bb10a901..5eeb1ea19 100644
--- a/xllm/core/distributed_runtime/rec_master.cpp
+++ b/xllm/core/distributed_runtime/rec_master.cpp
@@ -19,9 +19,15 @@ limitations under the License.
 #include
 #include
 #include
+#include
+
+#include
 
 #include "common/macros.h"
 #include "common/metrics.h"
+#include "common/rec_model_utils.h"
+#include "common/types.h"
+#include "framework/request/mm_data.h"
 #include "models/model_registry.h"
 #include "rec_engine.h"
 #include "runtime/xservice_client.h"
@@ -32,6 +38,184 @@
 namespace xllm {
 
+namespace {
+
+constexpr int32_t kDefaultPlaceholderToken = 20152019;
+
+RecType get_rec_type(const ModelArgs& model_args) {
+  const auto kind = get_rec_model_kind(model_args.model_type());
+  switch (kind) {
+    case RecModelKind::kOneRec:
+      return RecType::kOneRec;
+    case RecModelKind::kLlmRec:
+      return RecType::kLlmRec;
+    case RecModelKind::kNone:
+      return RecType::kNone;
+  }
+  return RecType::kNone;
+}
+
+bool process_onerec_inputs(
+    const std::optional>& prompt_tokens,
+    const std::optional>& input_tensors,
+    std::vector* local_prompt_tokens,
+    MMData* processed_mm_data,
+    OutputCallback callback) {
+  if (prompt_tokens.has_value() && input_tensors.has_value()) {
+    CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
+                        "prompt_tokens and input_tensors cannot both be set");
+    return false;
+  }
+
+  if (prompt_tokens.has_value()) {
+    local_prompt_tokens->assign(prompt_tokens.value().begin(),
+                                prompt_tokens.value().end());
+  }
+
+  if (input_tensors.has_value()) {
+    MMDict mm_dict;
+    mm_dict.reserve(input_tensors->size());
+    for (const auto& tensor : input_tensors.value()) {
+      mm_dict[tensor.name()] =
+          util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16);
+    }
+    *processed_mm_data = MMData(MMType::EMBEDDING, mm_dict);
+  }
+
+  if (local_prompt_tokens->empty() && !processed_mm_data->valid()) {
+    CALLBACK_WITH_ERROR(
+        StatusCode::INVALID_ARGUMENT,
+        "Rec model requires prompt_tokens or input_tensors to be provided");
+    return false;
+  }
+
+  return true;
+}
+
+bool process_llmrec_raw_inputs(
+    std::optional> input_tokens,
+    std::optional> input_indices,
+    std::optional>> input_embedding,
+    const ModelArgs& model_args,
+    std::vector* local_prompt_tokens,
+    MMData* processed_mm_data,
+    OutputCallback callback) {
+  std::vector<int32_t> local_input_tokens;
+  std::vector<int32_t> local_input_indices;
+  torch::Tensor input_tokens_tensor;
+  torch::Tensor input_indices_tensor;
+  torch::Tensor input_embedding_tensor;
+  int64_t embedding_rows = 0;
+
+  if (input_tokens.has_value()) {
+    const auto& tokens = input_tokens.value();
+    local_input_tokens.reserve(tokens.size());
+    for (const auto token : tokens) {
+      local_input_tokens.push_back(static_cast<int32_t>(token));
+    }
+    if (!local_input_tokens.empty()) {
+      input_tokens_tensor =
+          torch::from_blob(local_input_tokens.data(),
+                           {static_cast<int64_t>(local_input_tokens.size())},
+                           torch::dtype(torch::kInt32).device(torch::kCPU))
+              .clone();
+      processed_mm_data->add(
+          MMType::EMBEDDING, LLM_REC_INPUT_TOKENS, input_tokens_tensor);
+      local_prompt_tokens->assign(local_input_tokens.begin(),
+                                  local_input_tokens.end());
+    }
+  }
+
+  if (input_indices.has_value()) {
+    if (!input_tokens.has_value()) {
+      CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
+                          "LLMRec input indices require input tokens");
+      return false;
+    }
+    const auto& indices = input_indices.value();
+    local_input_indices.reserve(indices.size());
+    for (const auto index : indices) {
+      local_input_indices.push_back(static_cast<int32_t>(index));
+    }
+    if (local_input_indices.size() != local_input_tokens.size()) {
+      CALLBACK_WITH_ERROR(
+          StatusCode::INVALID_ARGUMENT,
+          "LLMRec input indices size does not match input tokens");
+      return false;
+    }
+    if (!local_input_indices.empty()) {
+      input_indices_tensor =
+          torch::from_blob(local_input_indices.data(),
+                           {static_cast<int64_t>(local_input_indices.size())},
+                           torch::dtype(torch::kInt32).device(torch::kCPU))
+              .clone();
+      processed_mm_data->add(
+          MMType::EMBEDDING, LLM_REC_INPUT_INDICES, input_indices_tensor);
+    }
+  }
+
+  if (input_embedding.has_value()) {
+    const auto&
embedding_vec = input_embedding.value(); + if (embedding_vec.empty()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input embedding is empty"); + return false; + } + const int64_t rows = static_cast(embedding_vec.size()); + const int64_t cols = static_cast(embedding_vec[0].size()); + if (cols != model_args.hidden_size()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input embedding has invalid hidden size"); + return false; + } + + std::vector flat_data; + flat_data.reserve(static_cast(rows * cols)); + for (const auto& row : embedding_vec) { + flat_data.insert(flat_data.end(), row.begin(), row.end()); + } + input_embedding_tensor = + torch::from_blob(flat_data.data(), + {rows, cols}, + torch::dtype(torch::kFloat32).device(torch::kCPU)) + .clone(); + processed_mm_data->add( + MMType::EMBEDDING, LLM_REC_INPUT_EMBEDDING, input_embedding_tensor); + embedding_rows = rows; + local_prompt_tokens->insert(local_prompt_tokens->end(), + static_cast(embedding_rows), + kDefaultPlaceholderToken); + } + + if (!local_input_indices.empty()) { + const int64_t total_size = + static_cast(local_input_tokens.size()) + embedding_rows; + std::unordered_set seen; + seen.reserve(local_input_indices.size()); + for (const auto index : local_input_indices) { + if (index < 0 || index >= total_size) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input indices contain invalid values"); + return false; + } + if (!seen.insert(index).second) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec input indices contain duplicate values"); + return false; + } + } + } + + if (local_prompt_tokens->empty()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is empty"); + return false; + } + + return true; +} + +} // namespace + RecMaster::RecMaster(const Options& options) : Master(options, EngineType::REC) { // Initialize with Rec engine type @@ -39,6 +223,10 @@ RecMaster::RecMaster(const Options& options) CHECK(engine_->init()); model_args_ = engine_->model_args(); + rec_type_ = get_rec_type(model_args_); + if (rec_type_ == RecType::kNone) { + LOG(ERROR) << "Unsupported rec model_type: " << model_args_.model_type(); + } bool enable_decode_response_to_service = false; if (options_.enable_service_routing()) { @@ -71,7 +259,6 @@ RecMaster::RecMaster(const Options& options) .enable_decode_response_to_service(enable_decode_response_to_service); scheduler_ = create_fixed_steps_scheduler(engine_.get(), scheduler_options); - // OmniRec model does not have a tokenizer chat_template_ = nullptr; tokenizer_ = nullptr; threadpool_ = @@ -108,39 +295,83 @@ RecMaster::~RecMaster() { void RecMaster::handle_request(std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, + std::optional> + input_tensors, RequestParams sp, OutputCallback callback) { - // add one pending request + if (rec_type_ != RecType::kOneRec) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "OneRec should use onerec input interface"); + return; + } + schedule_request( + std::move(sp), + std::move(callback), + [this, + prompt = std::move(prompt), + prompt_tokens = std::move(prompt_tokens), + input_tensors = std::move(input_tensors)]( + const RequestParams& params, + OutputCallback cb) mutable { + return generate_request(std::move(prompt), + std::move(prompt_tokens), + std::move(input_tensors), + params, + std::move(cb)); + }); +} + +void RecMaster::handle_request( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + RequestParams sp, 
+ OutputCallback callback) { + if (rec_type_ != RecType::kLlmRec) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec should use raw input interface"); + return; + } + schedule_request( + std::move(sp), + std::move(callback), + [this, + input_tokens = std::move(input_tokens), + input_indices = std::move(input_indices), + input_embedding = std::move(input_embedding)]( + const RequestParams& params, + OutputCallback cb) mutable { + return generate_request(std::move(input_tokens), + std::move(input_indices), + std::move(input_embedding), + params, + std::move(cb)); + }); +} + +void RecMaster::schedule_request(RequestParams sp, + OutputCallback callback, + RequestBuilder build_request) { scheduler_->incr_pending_requests(1); auto cb = [callback = std::move(callback), scheduler = scheduler_.get()](const RequestOutput& output) { output.log_request_status(); return callback(output); }; - // add into the queue threadpool_->schedule([this, - prompt = std::move(prompt), - prompt_tokens = std::move(prompt_tokens), - mm_data = std::move(mm_data), sp = std::move(sp), - callback = std::move(cb)]() mutable { + callback = std::move(cb), + build_request = std::move(build_request)]() mutable { AUTO_COUNTER(request_handling_latency_seconds_completion); - // remove the pending request after scheduling SCOPE_GUARD([this] { scheduler_->decr_pending_requests(); }); Timer timer; - // verify the prompt if (!sp.verify_params(callback)) { return; } - auto request = generate_request(std::move(prompt), - std::move(prompt_tokens), - std::move(mm_data), - sp, - callback); + auto request = build_request(sp, std::move(callback)); if (!request) { return; } @@ -155,40 +386,97 @@ void RecMaster::handle_request(std::string prompt, std::shared_ptr RecMaster::generate_request( std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, - RequestParams sp, + std::optional> input_tensors, + const RequestParams& sp, OutputCallback callback) { // For Rec model, prompt is expected to be empty and prompt_tokens should // contain the actual data Skip prompt empty check as mentioned in // requirements - Timer timer; - std::vector local_prompt_tokens; - - if (prompt_tokens.has_value()) { - local_prompt_tokens = std::move(prompt_tokens.value()); - LOG(INFO) - << "[Rec DEBUG] generate_request - received prompt_tokens.size(): " - << local_prompt_tokens.size() - << ", prompt.length(): " << prompt.length(); - } else if (!mm_data.has_value()) { - // sparse LLM - LOG(ERROR) << "Rec model requires prompt_tokens/embedding to be provided"; + if (rec_type_ == RecType::kNone) { + LOG(ERROR) << "Unsupported rec model_type: " << model_args_.model_type(); CALLBACK_WITH_ERROR( StatusCode::INVALID_ARGUMENT, - "Rec model requires prompt_tokens/embedding to be provided"); + std::string("Unsupported rec model_type: ") + model_args_.model_type()); + return nullptr; + } + + Timer timer; + std::vector local_prompt_tokens; + MMData processed_mm_data; + bool processed_ok = false; + + if (rec_type_ == RecType::kOneRec) { + processed_ok = process_onerec_inputs(prompt_tokens, + input_tensors, + &local_prompt_tokens, + &processed_mm_data, + callback); + } + + if (!processed_ok) { return nullptr; } COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + return build_request_common(std::move(prompt), + std::move(local_prompt_tokens), + std::move(processed_mm_data), + sp, + callback, + rec_type_ == RecType::kLlmRec); +} + +std::shared_ptr RecMaster::generate_request( + std::optional> input_tokens, + std::optional> 
input_indices, + std::optional>> input_embedding, + const RequestParams& sp, + OutputCallback callback) { + if (rec_type_ != RecType::kLlmRec) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "LLMRec inputs require rec_type kLlmRec"); + return nullptr; + } + + Timer timer; + std::vector local_prompt_tokens; + MMData processed_mm_data; + if (!process_llmrec_raw_inputs(std::move(input_tokens), + std::move(input_indices), + std::move(input_embedding), + model_args_, + &local_prompt_tokens, + &processed_mm_data, + callback)) { + return nullptr; + } + + COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + + return build_request_common(std::string(""), + std::move(local_prompt_tokens), + std::move(processed_mm_data), + sp, + callback, + true); +} + +std::shared_ptr RecMaster::build_request_common( + std::string prompt, + std::vector prompt_tokens, + MMData mm_data, + const RequestParams& sp, + OutputCallback callback, + bool build_stop_checker) { int32_t max_context_len = model_args_.max_position_embeddings(); if (!options_.enable_chunked_prefill()) { max_context_len = std::min(max_context_len, options_.max_tokens_per_batch()); } - if (local_prompt_tokens.size() >= max_context_len) { - LOG(ERROR) << "Prompt is too long: " << local_prompt_tokens.size(); + if (prompt_tokens.size() >= max_context_len) { + LOG(ERROR) << "Prompt is too long: " << prompt_tokens.size(); CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is too long"); return nullptr; } @@ -199,9 +487,7 @@ std::shared_ptr RecMaster::generate_request( max_tokens = kDefaultMaxTokens; } - // allocate enough capacity for prompt tokens, max tokens, and speculative - // tokens - size_t capacity = local_prompt_tokens.size() + max_tokens + + size_t capacity = prompt_tokens.size() + max_tokens + options_.num_speculative_tokens() + /*bonus_token*/ 1; if (options_.enable_schedule_overlap()) { capacity += options_.num_speculative_tokens() + 1; @@ -220,30 +506,55 @@ std::shared_ptr RecMaster::generate_request( sampling_param.is_embeddings = sp.is_embeddings; sampling_param.beam_width = sp.beam_width; if (best_of > sp.n) { - // enable logprobs for best_of to generate sequence logprob sampling_param.logprobs = true; } - // sampling_param.do_sample = sp.do_sample; bool stream = sp.streaming; - // results cannot be streamed when best_of != n if (best_of != sp.n) { stream = false; } - // std::unordered_set stop_tokens; - // std::vector> stop_sequences; - // StoppingChecker stopping_checker( - // max_tokens, - // max_context_len - options_.num_speculative_tokens(), - // , - // model_args_.eos_token_id(), - // sp.ignore_eos, - // std::move(stop_tokens), - // std::move(stop_sequences)); + StoppingChecker stopping_checker; + if (build_stop_checker) { + std::unordered_set stop_tokens; + if (sp.stop_token_ids.has_value()) { + const auto& stop_token_ids = sp.stop_token_ids.value(); + stop_tokens.insert(stop_token_ids.begin(), stop_token_ids.end()); + } else { + stop_tokens = model_args_.stop_token_ids(); + } + + std::vector> stop_sequences; + if (sp.stop.has_value()) { + if (!tokenizer_) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Tokenizer is required for stop sequences"); + return nullptr; + } + for (const auto& s : sp.stop.value()) { + std::vector tmp_tokens; + if (!tokenizer_->encode(s, &tmp_tokens)) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Failed to encode stop sequence"); + LOG(ERROR) << "Failed to encode stop sequence: " << s; + return nullptr; + } + 
stop_sequences.push_back(std::move(tmp_tokens)); + } + } + + stopping_checker = StoppingChecker( + max_tokens, + max_context_len - options_.num_speculative_tokens(), + model_args_.eos_token_id(), + sp.ignore_eos, + std::move(stop_tokens), + std::move(stop_sequences)); + } + RequestState req_state(std::move(prompt), - std::move(local_prompt_tokens), - mm_data.value_or(MMData{}), + std::move(prompt_tokens), + std::move(mm_data), std::move(sampling_param), std::move(stopping_checker), capacity, @@ -257,9 +568,8 @@ std::shared_ptr RecMaster::generate_request( callback, nullptr, sp.decode_address); - // TODO. add following when next pr (add is_rec_model and bos_token_id to - // RequestState). req_state.is_rec_model = true; req_state.bos_token_id = - // model_args_.bos_token_id(); + req_state.rec_type = rec_type_; + req_state.bos_token_id = model_args_.bos_token_id(); auto request = std::make_shared(sp.request_id, sp.x_request_id, sp.x_request_time, diff --git a/xllm/core/distributed_runtime/rec_master.h b/xllm/core/distributed_runtime/rec_master.h index 0ed5b76d3..2e9fe0396 100644 --- a/xllm/core/distributed_runtime/rec_master.h +++ b/xllm/core/distributed_runtime/rec_master.h @@ -16,12 +16,14 @@ limitations under the License. #pragma once #include +#include #include #include "framework/chat_template/jinja_chat_template.h" #include "framework/model/model_args.h" #include "master.h" #include "rec_engine.h" +#include "rec.pb.h" #include "scheduler/continuous_scheduler.h" #include "scheduler/fixed_steps_scheduler.h" #include "util/threadpool.h" @@ -37,21 +39,52 @@ class RecMaster : public Master { // completion/encode void handle_request(std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, + std::optional> + input_tensors, RequestParams sp, OutputCallback callback); + void handle_request( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + RequestParams sp, + OutputCallback callback); + // start the handling loop void run() override; private: + using RequestBuilder = + std::function(const RequestParams&, + OutputCallback)>; + + void schedule_request(RequestParams sp, + OutputCallback callback, + RequestBuilder build_request); + std::shared_ptr generate_request( std::string prompt, std::optional> prompt_tokens, - std::optional mm_data, - RequestParams sp, + std::optional> input_tensors, + const RequestParams& sp, OutputCallback callback); + std::shared_ptr generate_request( + std::optional> input_tokens, + std::optional> input_indices, + std::optional>> input_embedding, + const RequestParams& sp, + OutputCallback callback); + + std::shared_ptr build_request_common( + std::string prompt, + std::vector prompt_tokens, + MMData mm_data, + const RequestParams& sp, + OutputCallback callback, + bool build_stop_checker); + std::unique_ptr scheduler_; // model args ModelArgs model_args_; diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index 5699cc818..d531b4f6e 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -29,6 +29,7 @@ limitations under the License. 
#include "framework/model/model_input_params.h" #include "framework/request/sequence.h" #include "framework/sampling/sampling_params.h" +#include "rec_batch_input_builder.h" #include "runtime/params_utils.h" #include "util/slice.h" #include "util/tensor_helper.h" @@ -96,6 +97,10 @@ void Batch::add(const std::vector& sequences) { ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size, const ModelArgs& args) { + if (sequences_.empty() && !sequence_groups_.empty()) { + return prepare_rec_forward_input( + num_decoding_tokens, min_decoding_batch_size, args); + } BatchInputBuilder builder(sequences_, allowed_max_tokens_, input_embeddings_vec_, @@ -108,6 +113,43 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, min_decoding_batch_size); } +ForwardInput Batch::prepare_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size, + const ModelArgs& args, + ThreadPool* thread_pool) { + RecType rec_type = RecType::kNone; + if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) { + rec_type = sequence_groups_[0]->sequences()[0]->rec_type(); + } + + auto builder = RecBatchInputBuilder::create(rec_type, + sequence_groups_, + allowed_max_tokens_, + input_embeddings_vec_, + mm_data_vec_, + swap_block_transfer_infos_, + batch_id_, + &args, + thread_pool); + return builder->build_rec_forward_input(num_decoding_tokens, + min_decoding_batch_size); +} + +std::vector Batch::get_sequences() const { + if (!sequences_.empty()) { + return sequences_; + } + + std::vector result; + for (const auto* seq_group : sequence_groups_) { + const auto& sequences = seq_group->sequences(); + for (const auto& seq_ptr : sequences) { + result.push_back(seq_ptr.get()); + } + } + return result; +} + void Batch::dp_balance_shuffle_seqs() { // this shuffle operation is mainly used for npu with 24 cores // and specific mla op implementation @@ -217,7 +259,8 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output, // this means all sequences are in prefill stage status. const int64_t num_seqs = raw_output.outputs.size(); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto sequences = get_sequences(); + for (auto* seq : sequences) { if (seq->finished()) { output_idx++; continue; @@ -264,7 +307,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output, if (sample_output.embeddings.defined()) { const int64_t num_seqs = sample_output.embeddings.size(0); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto sequences = get_sequences(); + for (auto* seq : sequences) { CHECK_LT(output_idx, num_seqs); auto cur_seq_embed = safe_to(sample_output.embeddings[output_idx++], torch::kFloat32); @@ -277,7 +321,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output, // this means all sequences are in prefill stage status. 
const int64_t num_seqs = sample_output.next_tokens.size(0); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto sequences = get_sequences(); + for (auto* seq : sequences) { if (seq->finished()) { output_idx++; continue; @@ -352,7 +397,8 @@ void Batch::process_embedding_output(const torch::Tensor& output_embedding) { Token token(0); if (output_embedding.defined()) { int32_t slice_img_index = 0; - for (auto* seq : sequences_) { // TODO + const auto sequences = get_sequences(); + for (auto* seq : sequences) { const auto& mm_data = seq->get_mm_data(); auto pixel_values = mm_data.get_tensor_vec("pixel_values"); diff --git a/xllm/core/framework/batch/onerec_batch_input_builder.cpp b/xllm/core/framework/batch/onerec_batch_input_builder.cpp new file mode 100644 index 000000000..782595a60 --- /dev/null +++ b/xllm/core/framework/batch/onerec_batch_input_builder.cpp @@ -0,0 +1,957 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "onerec_batch_input_builder.h" + +#include +#include +#include +#include +#include +#include + +#include "framework/model/model_input_params.h" +#include "framework/request/sequence.h" +#include "framework/sampling/sampling_params.h" +#include "util/tensor_helper.h" +#include "util/utils.h" + +namespace xllm { + +// Use Meyers' Singleton pattern to avoid static initialization order fiasco +// This ensures the cache is initialized on first use, after all dependencies +// (like PyTorch runtime) are properly initialized. 
+OneRecBatchInputBuilder::HighPerformanceCache& +OneRecBatchInputBuilder::get_perf_cache() { + static HighPerformanceCache cache; + cache.ensure_tensors_initialized(); + return cache; +} + +OneRecBatchInputBuilder::OneRecBatchInputBuilder( + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + const uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool) + : sequence_groups_(sequence_groups), + allowed_max_tokens_(allowed_max_tokens), + input_embeddings_vec_(input_embeddings_vec), + mm_data_vec_(mm_data_vec), + swap_block_transfer_infos_(swap_block_transfer_infos), + batch_id_(batch_id), + args_(args), + thread_pool_(thread_pool) { + // Get references to function-local statics (safe initialization) + auto& perf_cache = get_perf_cache(); + perf_cache.memory_pool.reset(); +} + +ForwardInput OneRecBatchInputBuilder::build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) { + // Get reference to function-local static cache (safe initialization) + auto& perf_cache = get_perf_cache(); + + // ========== Global constant cache ========== + // Note: FIXED_POSITIONS is a simple vector, safe for static initialization + static const std::vector FIXED_POSITIONS = {0}; + // Note: FIXED_ENCODER_POSITIONS is now obtained from perf_cache to avoid + // static initialization order issues with torch::Tensor + + // ========== Fast sequence information extraction ========== + const int32_t num_sequences = + !sequence_groups_.empty() + ? std::accumulate(sequence_groups_.begin(), + sequence_groups_.end(), + 0, + [](int sum, const auto& group) { + return sum + group->sequences().size(); + }) + : 0; + const int32_t THREADPOOL_THRESHOLD = 16; + if (num_sequences == 0) { + return ForwardInput{}; + } + + // Get basic information of first sequence - optimize pointer access + Sequence* first_sequence = nullptr; + if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) { + first_sequence = sequence_groups_[0]->sequences()[0].get(); + } + + if (!first_sequence) { + return ForwardInput{}; + } + + const uint32_t seq_len = first_sequence->num_tokens(); + const uint32_t num_decoder_embeddings = + first_sequence->num_decoder_embeddings(); + const uint32_t n_prompt_tokens = first_sequence->num_prompt_tokens(); + const bool is_first_prefill = (first_sequence->num_generated_tokens() == 0); + // const uint64_t model_version = first_sequence->get_model_version(); + + // ========== High-performance encoder tokens construction ========== + auto build_encoder_tokens_optimized = [&]() + -> const std::vector& { + auto& cache_data = perf_cache.cache_data; + + // encoder doesn't use cache key, because encoder doesn't use encoder_tokens + // in non-first prefill scenarios, only uses encoder_seq_len + if (!is_first_prefill) { + return cache_data.encoder_tokens; + } + + // Optimization: Use SIMD-friendly memory access patterns + cache_data.encoder_tokens.clear(); + cache_data.encoder_seq_lens.clear(); + + // Optimization for scenarios where sequences have different lengths across + // sequence groups Pre-calculate total token count to avoid multiple memory + // reallocations + int32_t total_tokens = 0; + for (const auto& group_ptr : sequence_groups_) { + if (!group_ptr->sequences().empty()) { + // Sequences within group have same length, only need to get first + // sequence's length + const int32_t group_encoder_seq_len = + 
group_ptr->sequences()[0]->encoder_tokens().size(); + total_tokens += group_encoder_seq_len * group_ptr->sequences().size(); + } + } + + cache_data.encoder_tokens.reserve(total_tokens); + cache_data.encoder_seq_lens.resize(num_sequences); + cache_data.encoder_sparse_embeddings.clear(); + cache_data.encoder_sparse_embeddings.reserve(num_sequences); + cache_data.decoder_context_embeddings.clear(); + cache_data.decoder_context_embeddings.reserve(num_sequences); + + // Process by groups in batch + int32_t global_seq_idx = 0; + for (const auto& group_ptr : sequence_groups_) { + const auto& group = *group_ptr; + const int32_t group_size = group.sequences().size(); + + if (group_size == 0) continue; + + const int32_t group_encoder_seq_len = + group.sequences()[0]->encoder_seq_len(); + + // Batch set same values + std::fill_n(&cache_data.encoder_seq_lens[global_seq_idx], + group_size, + group_encoder_seq_len); + + // Batch copy tokens by sequence and collect sparse_embedding + for (const auto& sequence : group.sequences()) { + const auto& encoder_tokens = sequence->encoder_tokens(); + const int32_t* src_ptr = encoder_tokens.data(); + const int32_t group_encoder_seq_len = encoder_tokens.size(); + + // Use efficient batch insertion + if (group_encoder_seq_len > 0) { + cache_data.encoder_tokens.insert(cache_data.encoder_tokens.end(), + src_ptr, + src_ptr + group_encoder_seq_len); + } + // Collect sparse_embedding + auto mm_data = sequence->get_mm_data(); + auto sparse_embedding_optional = + mm_data.get(Sequence::ENCODER_SPARSE_EMBEDDING_NAME); + if (sparse_embedding_optional.has_value()) { + cache_data.encoder_sparse_embeddings.push_back( + sparse_embedding_optional.value()); + } + + auto decoder_context_embedding_optional = mm_data.get( + Sequence::DECODER_CONTEXT_EMBEDDING_NAME); + if (decoder_context_embedding_optional.has_value()) { + cache_data.decoder_context_embeddings.push_back( + decoder_context_embedding_optional.value()); + } + } + + global_seq_idx += group_size; + } + + return cache_data.encoder_tokens; + }; + + // ========== High-performance decoder data construction ========== + auto build_decoder_data_optimized = [&]() { + // Pre-allocate all containers to avoid dynamic expansion + const size_t total_tokens = num_sequences * seq_len; + std::vector flatten_tokens_vec; + flatten_tokens_vec.reserve(total_tokens); + std::vector sampling_params; + sampling_params.reserve(num_sequences); + std::vector selected_token_idxes; + selected_token_idxes.reserve(num_sequences); + std::vector sample_idxes; + sample_idxes.reserve(num_sequences); + std::vector> generated_tokens; + generated_tokens.reserve(num_sequences); + + // Multi-threading optimization: Use parallel processing when sequence count + // exceeds threshold and thread pool is available + ThreadPool* threadpool = thread_pool_; + if (num_sequences >= THREADPOOL_THRESHOLD && threadpool != nullptr) { + // Thread-safe result containers + std::vector> thread_flatten_tokens(num_sequences); + std::vector thread_sampling_params( + num_sequences); + std::vector thread_selected_token_idxes(num_sequences); + std::vector thread_sample_idxes(num_sequences); + std::vector> thread_generated_tokens(num_sequences); + + // Calculate thread allocation + const size_t num_threads = + std::min(static_cast(num_sequences), static_cast(16)); + const size_t sequences_per_thread = + (num_sequences + num_threads - 1) / num_threads; + + std::vector> futures; + std::vector>> promises; + futures.reserve(num_threads); + promises.reserve(num_threads); + + // Parallel 
processing function + auto process_sequences_range = [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; + i < end_idx && i < static_cast(num_sequences); + ++i) { + const Sequence* sequence = nullptr; + // Get sequence from sequence_groups + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + if (seq_idx + group->sequences().size() > i) { + sequence = group->sequences()[i - seq_idx].get(); + break; + } + seq_idx += group->sequences().size(); + } + + if (!sequence) continue; + + const auto& token_ids = sequence->tokens(); + + // Build generated tokens + auto& cur_generated_tokens = thread_generated_tokens[i]; + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + + // Build flatten tokens + auto& cur_flatten_tokens = thread_flatten_tokens[i]; + cur_flatten_tokens.reserve(seq_len); + cur_flatten_tokens.insert(cur_flatten_tokens.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Set sampling parameters + thread_sampling_params[i] = sequence->sampling_param(); + thread_sample_idxes[i] = static_cast(i); + } + }; + + // Launch parallel tasks + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_idx = thread_idx * sequences_per_thread; + size_t end_idx = std::min(start_idx + sequences_per_thread, + static_cast(num_sequences)); + + if (start_idx >= static_cast(num_sequences)) break; + + auto promise = std::make_shared>(); + futures.push_back(promise->get_future()); + promises.push_back(promise); + + threadpool->schedule( + [process_sequences_range, start_idx, end_idx, promise]() mutable { + try { + process_sequences_range(start_idx, end_idx); + promise->set_value(); + } catch (...) 
{ + promise->set_exception(std::current_exception()); + } + }); + } + + // Wait for all tasks to complete + for (auto& future : futures) { + future.get(); + } + + // Merge results + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + for (int32_t i = 0; i < num_sequences; ++i) { + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + thread_flatten_tokens[i].begin(), + thread_flatten_tokens[i].end()); + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(thread_sampling_params[i]); + sample_idxes.push_back(thread_sample_idxes[i]); + generated_tokens.push_back(std::move(thread_generated_tokens[i])); + } + } else { + // Original single-thread processing logic + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + for (const auto& sequence : group->sequences()) { + const auto& token_ids = sequence->tokens(); + + // Optimize generated tokens construction + auto& cur_generated_tokens = generated_tokens.emplace_back(); + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + // Optimize token processing - batch operations + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Simplify sampling parameter processing + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(sequence->sampling_param()); + sample_idxes.push_back(seq_idx); + seq_idx++; + } + } + } + + return std::make_tuple(std::move(flatten_tokens_vec), + std::move(sampling_params), + std::move(selected_token_idxes), + std::move(sample_idxes), + std::move(generated_tokens)); + }; + + // ========== Comprehensive parallel execution of optimized data construction + // ========== Use thread pool to execute all independent data construction + // tasks in parallel + std::future&> encoder_future; + std::future, + std::vector, + std::vector, + std::vector, + std::vector>>> + decoder_future; + + // Declare variables to store results + const std::vector* encoder_tokens_ptr = nullptr; + std::vector flatten_tokens_vec; + std::vector sampling_params; + std::vector selected_token_idxes; + std::vector sample_idxes; + std::vector> generated_tokens; + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Use ThreadPool's schedule method to execute independent tasks in parallel + // build_decoder_data_optimized handles multi-threading internally, no external + // parallel calls + + // Task 1: build_encoder_tokens_optimized + std::promise*> encoder_promise; + auto encoder_future = encoder_promise.get_future(); + thread_pool_->schedule([&, promise = std::move(encoder_promise)]() mutable { + const auto& result = build_encoder_tokens_optimized(); + promise.set_value(&result); + }); + // Wait for encoder to complete + encoder_tokens_ptr = encoder_future.get(); + // Task 2: build_decoder_data_optimized executes directly, handles + // multi-threading internally + std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + sample_idxes, + generated_tokens) = build_decoder_data_optimized(); + } else { + // Single-thread execution (original logic) + encoder_tokens_ptr = &build_encoder_tokens_optimized(); + std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + 
sample_idxes, + generated_tokens) = build_decoder_data_optimized(); + } + + const auto& encoder_tokens = *encoder_tokens_ptr; + + // ========== High-performance ForwardInput construction ========== + ForwardInput forward_input; + auto& input_params = forward_input.input_params; + auto& onerec_params = input_params.mutable_onerec_params(); + auto& cache_data = perf_cache.cache_data; + + // Initialize key fields for asynchronous tasks + const int64_t bs = sequence_groups_.size(); + const int64_t group_width = + sequence_groups_.empty() ? 1 : sequence_groups_[0]->sequences().size(); + + std::vector> decoder_embedding_futures; + torch::Tensor result_embedding; + + // ========== Parallel tensor construction tasks ========== + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Only use parallelization for time-consuming tasks (token_ids and + // encoder_token_ids) + std::promise token_ids_promise; + std::promise encoder_token_ids_promise; + + auto token_ids_future = token_ids_promise.get_future(); + // auto encoder_token_ids_future = encoder_token_ids_promise.get_future(); + + // Task 1: Build token_ids tensor - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + thread_pool_->schedule([&flatten_tokens_vec, + promise = std::move(token_ids_promise)]() mutable { + try { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + auto tensor = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + promise.set_value(std::move(tensor)); + } catch (...) { + promise.set_exception(std::current_exception()); + } + }); + + // Task 2: Build encoder_token_ids tensor (if needed) - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + /* + thread_pool_->schedule( + [&encoder_tokens, + promise = std::move(encoder_token_ids_promise)]() mutable { + try { + torch::Tensor tensor; + if (!encoder_tokens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid + // clone operations + tensor = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + } + promise.set_value(std::move(tensor)); + } catch (...) 
{ + promise.set_exception(std::current_exception()); + } + }); + */ + if (!perf_cache.cache_data.decoder_context_embeddings.empty()) { + // Task 3: Synchronously process decoder_embedding, inner group dimension + // parallelization optimization + + // Optimization: Directly get shape information from first embedding to + // avoid torch::cat + auto first_embedding = + perf_cache.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + + // Create tensor on pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + + // Calculate total sequence length, pre-allocate context_len + seq_len + int64_t total_seq_len = context_len + seq_len; + + auto combined_embedding = + torch::empty({bs, group_width, total_seq_len, hidden_size}, options); + + // High-performance optimization: group dimension segmented + // parallelization + void* dst_data = combined_embedding.data_ptr(); + + // Get element size (supports float, bfloat16 and other types) + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t group_stride = total_seq_len * hidden_size * element_size; + const size_t batch_stride = + group_width * total_seq_len * hidden_size * element_size; + + // Parallelization strategy: segment by group dimension, consistent with + // thread calculations elsewhere + const size_t num_threads = + std::min(static_cast(group_width), static_cast(16)); + const size_t groups_per_thread = + (group_width + num_threads - 1) / num_threads; + + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_group = thread_idx * groups_per_thread; + size_t end_group = std::min(start_group + groups_per_thread, + static_cast(group_width)); + + if (start_group >= static_cast(group_width)) break; + + std::promise promise; + decoder_embedding_futures.push_back(promise.get_future()); + + thread_pool_->schedule( + [start_group, + end_group, + bs, + dst_data, + context_len, + hidden_size, + element_size, + batch_stride, + group_stride, + context_size, + embeddings = perf_cache.cache_data.decoder_context_embeddings, + dst_tensor = combined_embedding, + promise = std::move(promise)]() mutable { + // Copy context_embedding for specified group range of each batch + for (int64_t b = 0; b < bs; ++b) { + // Optimization: Access corresponding batch embedding directly + // through index + const void* batch_src = embeddings[b].data_ptr(); + auto* batch_dst = + static_cast(dst_data) + b * batch_stride; + + for (size_t g = start_group; g < end_group; ++g) { + std::memcpy( + batch_dst + g * group_stride, batch_src, context_size); + } + } + promise.set_value(); + }); + } + + result_embedding = combined_embedding; + } + + // Task 4: Build sequence length vector - changed to serial execution (very + // time-consuming, ~0.001785ms) + std::vector cu_seq_lens, q_cu_seq_lens; +#if defined(USE_NPU) + // use all prefill; + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + 
cu_seq_lens.push_back(cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + // Task 5: Build encoder_seq_lens_tensor - changed to serial execution (less + // time-consuming) + torch::Tensor encoder_seq_lens_tensor; + if (!cache_data.encoder_seq_lens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + + // Set basic parameters simultaneously (not dependent on asynchronous tasks) + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + forward_input.positions = perf_cache.fixed_positions_tensor; + + // Wait and collect results + forward_input.token_ids = token_ids_future.get(); + // auto encoder_token_ids = encoder_token_ids_future.get(); + + // seq_lens has been changed to serial execution, use the constructed + // variable directly + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + // encoder_seq_lens_tensor has been changed to serial execution, use the + // constructed variable directly + if (encoder_seq_lens_tensor.defined()) { + onerec_params.encoder_seq_lens_tensor = + std::move(encoder_seq_lens_tensor); + onerec_params.encoder_seq_lens = cache_data.encoder_seq_lens; + } + onerec_params.encoder_positions = perf_cache.fixed_encoder_positions_tensor; + } else { + // Single-threaded execution (original logic) + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + forward_input.token_ids = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(forward_input.token_ids.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + forward_input.positions = perf_cache.fixed_positions_tensor; + + if (!encoder_tokens.empty()) { + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + onerec_params.encoder_token_ids = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(onerec_params.encoder_token_ids.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + 
} + onerec_params.encoder_positions = perf_cache.fixed_encoder_positions_tensor; + // Pre-allocate and batch fill + std::vector cu_seq_lens, q_cu_seq_lens; +#if defined(USE_NPU) + // use all prefill; + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + cu_seq_lens.push_back(cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + if (!cache_data.encoder_seq_lens.empty()) { + // Set OneRecModelInputParams encoder data + onerec_params.encoder_seq_lens = cache_data.encoder_seq_lens; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + onerec_params.encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(onerec_params.encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + } + + // ========== Parallel processing of independent code blocks ========== + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Define promise/future for parallel tasks + std::promise block_tables_promise; + auto block_tables_future = block_tables_promise.get_future(); + + // Task 1: Empty block tables processing - use thread pool (relatively + // time-consuming) + thread_pool_->schedule( + [&input_params, num_sequences, &block_tables_promise]() mutable { + try { + std::vector> empty_block_tables(num_sequences); + util::pad_2d_vector(empty_block_tables, 0); + // Optimization: Use create_2d_tensor_optimized, has special + // optimization for all-zero matrices + input_params.block_tables = + create_2d_tensor(empty_block_tables, torch::kInt); + + std::vector paged_kv_indptr(num_sequences + 1, 0); + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.new_cache_slots = + torch::empty({static_cast(paged_kv_indptr.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.new_cache_slots.data_ptr(), + 
paged_kv_indptr.data(), + paged_kv_indptr.size() * sizeof(int)); + + block_tables_promise.set_value(); + } catch (...) { + block_tables_promise.set_exception(std::current_exception()); + } + }); + + // Optimization: Merge small tasks into sequential execution to reduce + // thread switching overhead Cross-attention parameter construction - use + // placeholder + onerec_params.cross_attn_kv_cu_seq_lens = torch::zeros({1}, torch::kInt); + onerec_params.cross_attn_kv_cu_seq_lens_vec = {0}; + onerec_params.cross_attn_block_tables = torch::zeros({1, 1}, torch::kInt); + + // Sampling parameter processing + if (!selected_token_idxes.empty()) { + forward_input.sampling_params.init(sampling_params, + selected_token_idxes, + sample_idxes, + std::vector>{}, + std::vector>{}, + std::vector{}); + } + + // First prefill processing - use placeholder + if (is_first_prefill) { + // Use placeholder instead of complex cross_attn_new_cache_slots + // construction + onerec_params.cross_attn_new_cache_slots = torch::zeros({1}, torch::kInt); + } + + // Wait for parallel tasks to complete (only block_tables uses thread pool) + block_tables_future.wait(); + } else { + // ========== Non-parallel case: sequential processing ========== + // Optimize empty block tables processing + std::vector> empty_block_tables(num_sequences); + util::pad_2d_vector(empty_block_tables, 0); + // Optimization: Use create_2d_tensor_optimized, has special optimization + // for all-zero matrices + input_params.block_tables = + create_2d_tensor(empty_block_tables, torch::kInt); + + std::vector paged_kv_indptr(num_sequences + 1, 0); + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.new_cache_slots = + torch::empty({static_cast(paged_kv_indptr.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.new_cache_slots.data_ptr(), + paged_kv_indptr.data(), + paged_kv_indptr.size() * sizeof(int)); + + // ========== Cross-attention parameter construction (using placeholder) + // ========== Use placeholder tensor instead of actual data + onerec_params.cross_attn_kv_cu_seq_lens = torch::zeros({1}, torch::kInt); + onerec_params.cross_attn_kv_cu_seq_lens_vec = {0}; + + // Use placeholder tensor instead of actual data + onerec_params.cross_attn_block_tables = torch::zeros({1, 1}, torch::kInt); + + // ========== Optimize sampling parameter processing ========== + if (!selected_token_idxes.empty()) { + forward_input.sampling_params.init(sampling_params, + selected_token_idxes, + sample_idxes, + std::vector>{}, + std::vector>{}, + std::vector{}); + } + + // ========== First prefill processing (using placeholder) ========== + if (is_first_prefill) { + // Use placeholder tensor instead of actual data + onerec_params.cross_attn_new_cache_slots = torch::zeros({1}, torch::kInt); + } + } + + // ========== Common parameter settings ========== + // Batch set other parameters + input_params.embedding_ids.assign(num_sequences, 0); + +#if defined(USE_NPU) + auto prefill_indices = util::find_ones_indices(input_params.q_seq_lens_vec); + input_params.decode_seq_range = + std::make_pair(0, static_cast(flatten_tokens_vec.size())); +#else + input_params.decode_seq_range = { + 0, static_cast(flatten_tokens_vec.size())}; +#endif + + // OneRec model parameters + onerec_params.rec_stage = OneRecModelInputParams::RecStage::PREFILL; + onerec_params.is_hybrid_mode = false; + onerec_params.has_encoder_output = true; + onerec_params.is_first_prefill = 
is_first_prefill; + onerec_params.bs = bs; + onerec_params.group_width = group_width; + onerec_params.seq_len = seq_len; + onerec_params.encoder_max_seq_len = + cache_data.encoder_seq_lens.empty() + ? 0 + : *std::max_element(cache_data.encoder_seq_lens.begin(), + cache_data.encoder_seq_lens.end()); + + onerec_params.generated_tokens = std::move(generated_tokens); + + // Process sparse_embedding: Efficiently concatenate from cache_data + if (!perf_cache.cache_data.encoder_sparse_embeddings.empty()) { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + onerec_params.encoder_sparse_embedding = + torch::cat(perf_cache.cache_data.encoder_sparse_embeddings, /*dim=*/0); + } + + if (!perf_cache.cache_data.decoder_context_embeddings.empty()) { + // Get group_width + const int64_t group_width_val = onerec_params.group_width; + if (group_width_val == 1 && seq_len == 0) { + // Optimization: When bs==1, directly use the first embedding to avoid + // unnecessary torch::cat + if (bs == 1) { + onerec_params.decoder_context_embedding = + perf_cache.cache_data.decoder_context_embeddings[0]; + } else { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + auto original_context_embedding = torch::cat( + perf_cache.cache_data.decoder_context_embeddings, /*dim=*/0); + onerec_params.decoder_context_embedding = original_context_embedding; + } + } else if (group_width_val == 1 && seq_len > 0) { + // Handle the scenario where group_width==1 and seq_len>0 + // Get information from the first embedding + const auto& first_embedding = + perf_cache.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + int64_t total_seq_len = context_len + seq_len; + + // Allocate a tensor of shape {bs, 1, total_seq_len, hidden_size}, + // optimized with pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + auto combined_embedding = + torch::empty({bs, 1, total_seq_len, hidden_size}, options); + + // Single-threaded copy of context_len portion of data + void* dst_data = combined_embedding.data_ptr(); + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t batch_stride = total_seq_len * hidden_size * element_size; + + // Copy context_embedding for each batch + for (int64_t b = 0; b < bs; ++b) { + const void* batch_src = + perf_cache.cache_data.decoder_context_embeddings[b].data_ptr(); + auto* batch_dst = static_cast(dst_data) + b * batch_stride; + std::memcpy(batch_dst, batch_src, context_size); + } + onerec_params.decoder_context_embedding = combined_embedding; + } else { + for (auto& future : decoder_embedding_futures) { + future.get(); + } + onerec_params.decoder_context_embedding = std::move(result_embedding); + } + } + + return forward_input; +} + +} // namespace xllm diff --git a/xllm/core/framework/batch/onerec_batch_input_builder.h b/xllm/core/framework/batch/onerec_batch_input_builder.h new file mode 100644 index 000000000..ac5179d88 --- /dev/null +++ b/xllm/core/framework/batch/onerec_batch_input_builder.h @@ -0,0 +1,119 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include +#include + +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/request/mm_data.h" +#include "framework/request/sequence.h" +#include "framework/request/sequences_group.h" +#include "rec_batch_input_builder.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" + +namespace xllm { + +class OneRecBatchInputBuilder : public RecBatchInputBuilder { + public: + explicit OneRecBatchInputBuilder( + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + const uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool = nullptr); + + public: + ForwardInput build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) override; + + private: + const std::vector& sequence_groups_; + const std::vector& allowed_max_tokens_; + const std::vector& input_embeddings_vec_; + const std::vector& mm_data_vec_; + std::vector* swap_block_transfer_infos_ = nullptr; + const uint64_t batch_id_; + const ModelArgs* args_ = nullptr; + ThreadPool* thread_pool_ = nullptr; + + // High performance cache system + struct HighPerformanceCache { + // Memory pool - avoid frequent allocation/deallocation + struct MemoryPool { + std::vector> int32_pools; + size_t pool_index = 0; + + std::vector& get_int32_vector(size_t reserve_size = 0) { + if (pool_index >= int32_pools.size()) { + int32_pools.emplace_back(); + } + auto& vec = int32_pools[pool_index++]; + vec.clear(); + if (reserve_size > 0) vec.reserve(reserve_size); + return vec; + } + + void reset() { pool_index = 0; } + }; + + // Cache data structure + struct CacheData { + std::vector encoder_tokens; + std::vector encoder_seq_lens; + std::vector encoder_sparse_embeddings; + std::vector decoder_context_embeddings; + }; + + // Pre-created constant tensors - lazy initialized to avoid static + // initialization order issues + torch::Tensor fixed_positions_tensor; + torch::Tensor fixed_encoder_positions_tensor; + torch::Tensor empty_tensor; + bool tensors_initialized = false; + + MemoryPool memory_pool; + CacheData cache_data; + + // Default constructor - does NOT create tensors to avoid static + // initialization order fiasco + HighPerformanceCache() = default; + + // Lazy initialization of tensors - must be called before first use + void ensure_tensors_initialized() { + if (!tensors_initialized) { + fixed_positions_tensor = torch::tensor({0}, torch::kInt); + fixed_encoder_positions_tensor = torch::tensor({0}, torch::kInt); + empty_tensor = torch::tensor(std::vector{}, torch::kInt); + tensors_initialized = true; + } + } + }; + + // Use function-local static to ensure proper initialization order + // (Meyers' Singleton pattern) + static HighPerformanceCache& get_perf_cache(); +}; + +} // namespace xllm diff --git a/xllm/core/framework/batch/rec_batch_input_builder.cpp 
b/xllm/core/framework/batch/rec_batch_input_builder.cpp new file mode 100644 index 000000000..b86449e2b --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.cpp @@ -0,0 +1,58 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_batch_input_builder.h" + +#include + +#include +#include + +#include "onerec_batch_input_builder.h" + +namespace xllm { + +std::unique_ptr RecBatchInputBuilder::create( + RecType rec_type, + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool) { + switch (rec_type) { + case RecType::kOneRec: + return std::make_unique( + sequence_groups, + allowed_max_tokens, + input_embeddings_vec, + mm_data_vec, + swap_block_transfer_infos, + batch_id, + args, + thread_pool); + case RecType::kNone: + case RecType::kLlmRec: + break; + } + + LOG(FATAL) << "Unsupported RecType for RecBatchInputBuilder: " + << static_cast(rec_type); + return nullptr; +} + +} // namespace xllm diff --git a/xllm/core/framework/batch/rec_batch_input_builder.h b/xllm/core/framework/batch/rec_batch_input_builder.h new file mode 100644 index 000000000..831bba598 --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.h @@ -0,0 +1,53 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include + +#include +#include +#include + +#include "framework/model/model_args.h" +#include "framework/request/mm_data.h" +#include "framework/request/rec_type.h" +#include "framework/request/sequences_group.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" + +namespace xllm { + +class RecBatchInputBuilder { + public: + virtual ~RecBatchInputBuilder() = default; + + virtual ForwardInput build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) = 0; + + static std::unique_ptr create( + RecType rec_type, + const std::vector& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + std::vector* swap_block_transfer_infos, + uint64_t batch_id, + const ModelArgs* args, + ThreadPool* thread_pool = nullptr); +}; + +} // namespace xllm diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp index 26578aa4b..4d8d96d93 100644 --- a/xllm/core/framework/request/sequence.cpp +++ b/xllm/core/framework/request/sequence.cpp @@ -34,6 +34,85 @@ limitations under the License. namespace xllm { +namespace { +constexpr size_t kDecoderBosTokenCount = 1; +constexpr size_t kDecoderMaxTokenCount = 4; +} // namespace + +const std::string Sequence::ENCODER_SPARSE_EMBEDDING_NAME = "sparse_embedding"; +const std::string Sequence::DECODER_CONTEXT_EMBEDDING_NAME = + "decoder_context_embedding"; + +void Sequence::init_onerec_sequence( + const std::vector& prompt_token_ids, + torch::Tensor input_embedding) { + auto& onerec_state = onerec_state_.emplace(); + if (!prompt_token_ids.empty()) { + onerec_state.encoder_tokens.assign(prompt_token_ids.begin(), + prompt_token_ids.end()); + onerec_state.num_encoder_tokens = prompt_token_ids.size(); + } else { + auto encoder_sparse_embedding = + mm_data_.get(ENCODER_SPARSE_EMBEDDING_NAME); + CHECK(encoder_sparse_embedding.has_value()) + << "encoder sparse embedding not found in mm_data"; + onerec_state.num_encoder_tokens = encoder_sparse_embedding.value().size(0); + } + + auto decoder_context_embedding = + mm_data_.get(DECODER_CONTEXT_EMBEDDING_NAME); + + size_t capacity = kDecoderMaxTokenCount; + if (decoder_context_embedding.has_value()) { + num_prompt_tokens_ = 0; + onerec_state.num_decoder_embeddings = + decoder_context_embedding.value().size(0); + capacity = onerec_state.num_decoder_embeddings + capacity - + kDecoderBosTokenCount; + } else { + num_prompt_tokens_ = kDecoderBosTokenCount; + } + + tokens_.resize(capacity); + for (size_t i = 0; i < num_prompt_tokens_; ++i) { + tokens_[num_tokens_++] = sequence_params_.bos_token_id; + token_to_count_map_[sequence_params_.bos_token_id]++; + } + volatile_num_prompt_tokens_ = num_prompt_tokens_; + input_embedding_ = std::move(input_embedding); + cur_generated_token_idx_ = num_prompt_tokens_; + logprob_state_ = std::make_unique(num_prompt_tokens_, capacity); +} + +SequenceOutput Sequence::build_onerec_output(const Slice& ids, + size_t size, + SequenceOutput output) const { + output.token_ids = ids.slice(num_prompt_tokens_, size); + return output; +} + +SequenceOutput Sequence::build_onerec_streaming_output( + const Slice& ids, + size_t size) const { + SequenceOutput output; + output.index = index_; + output.token_ids = ids.slice(num_prompt_tokens_, size); + return output; +} + +SequenceOutput Sequence::generate_onerec_output(const Slice& ids, + size_t size) const { 
+ SequenceOutput output; + output.index = index_; + if (output_embedding_.defined()) { + output.embedding = output_embedding_; + } + if (finish_reason_ != FinishReason::NONE) { + output.finish_reason = finish_reason_.to_string(); + } + return build_onerec_output(ids, size, std::move(output)); +} + Sequence::Sequence(size_t index, const std::vector& prompt_token_ids, torch::Tensor input_embedding, @@ -45,7 +124,13 @@ Sequence::Sequence(size_t index, latest_generate_time_(absl::Now()), sequence_params_(seq_params), decoder_(std::move(decoder)), - termination_flag_(std::make_shared>(false)) { + termination_flag_(std::make_shared>(false)), + rec_type_(seq_params.rec_type) { + if (is_onerec_model()) { + init_onerec_sequence(prompt_token_ids, std::move(input_embedding)); + return; + } + CHECK(!prompt_token_ids.empty()) << "empty prompt token ids"; auto capacity = sequence_params_.seq_capacity; CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small"; @@ -85,6 +170,8 @@ Sequence::Sequence(const Sequence& other) num_tokens_(other.num_tokens_), token_to_count_map_(other.token_to_count_map_), num_prompt_tokens_(other.num_prompt_tokens_), + onerec_state_(other.onerec_state_), + rec_type_(other.rec_type_), volatile_num_prompt_tokens_(other.volatile_num_prompt_tokens_), embedding_id_(other.embedding_id_), finished_(other.finished_), @@ -103,8 +190,10 @@ void Sequence::append_token(const Token& token) { CHECK_LT(num_tokens_, tokens_.size()) << "exceed the token capacity of the sequence"; CHECK(!finished_) << "cannot append token to a finished sequence"; - CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_chunked_prefill_stage()) - << "cannot append token to a prefill sequence"; + if (!is_onerec_model()) { + CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_chunked_prefill_stage()) + << "cannot append token to a prefill sequence"; + } if (!sequence_params_.enable_schedule_overlap) { // check if the token is the first token after the prompt @@ -250,6 +339,10 @@ std::optional Sequence::generate_streaming_output( AUTO_COUNTER(detokenization_latency_seconds_stream); const auto ids = Slice(tokens_, size); + if (is_onerec_model()) { + return build_onerec_streaming_output(ids, size); + } + // record the start index of token ids const size_t start = decoder_.output_offset(); auto delta = decoder_.decode(ids, tokenizer); @@ -333,6 +426,19 @@ SequenceOutput Sequence::generate_output(const Tokenizer& tokenizer) { } } + if (is_onerec_model()) { + return generate_onerec_output(ids, size); + } + + SequenceOutput output; + output.index = index_; + if (output_embedding_.defined()) { + output.embedding = output_embedding_; + } + if (finish_reason_ != FinishReason::NONE) { + output.finish_reason = finish_reason_.to_string(); + } + // record the start index of token ids const size_t start = decoder_.output_offset(); @@ -349,16 +455,7 @@ SequenceOutput Sequence::generate_output(const Tokenizer& tokenizer) { ss << decoder_.decode(ids.slice(0, end), tokenizer); } - SequenceOutput output; - output.index = index_; output.text = ss.str(); - if (output_embedding_.defined()) { - output.embedding = output_embedding_; - } - - if (finish_reason_ != FinishReason::NONE) { - output.finish_reason = finish_reason_.to_string(); - } const size_t end = decoder_.output_offset(); output.token_ids = ids.slice(start, end); @@ -402,6 +499,10 @@ bool Sequence::finished() const { return finished_; } + if (is_onerec_model() && num_tokens_ == num_prompt_tokens_) { + return false; + } + // Embedding sequence never be finished until it updates 
its embeddings if (finish_status_invalidated_ && sequence_params_.sampling_param->is_embeddings) { diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h index ee9b3f210..cf12f0d23 100644 --- a/xllm/core/framework/request/sequence.h +++ b/xllm/core/framework/request/sequence.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include +#include #include #include "core/common/types.h" @@ -30,6 +32,7 @@ limitations under the License. #include "framework/block/block.h" #include "incremental_decoder.h" #include "mm_data.h" +#include "rec_type.h" #include "request_output.h" #include "sequence_kv_state.h" #include "sequence_logprob_state.h" @@ -73,6 +76,10 @@ struct SequenceParams { // enable_schedule_overlap or not. default = false. bool enable_schedule_overlap = false; + RecType rec_type = RecType::kNone; + + int32_t bos_token_id = 0; + // sampling params // reference from request RequestSamplingParam* sampling_param; // not owned @@ -268,11 +275,57 @@ class Sequence final { // get sequence id int32_t seq_id() const { return seq_id_; } + const std::vector& encoder_tokens() const { + static const std::vector kEmpty; + if (!onerec_state_.has_value()) { + return kEmpty; + } + return onerec_state_->encoder_tokens; + } + + size_t encoder_seq_len() const { + return onerec_state_.has_value() ? onerec_state_->num_encoder_tokens + : 0; + } + + size_t num_decoder_embeddings() const { + return onerec_state_.has_value() ? onerec_state_->num_decoder_embeddings + : 0; + } + + RecType rec_type() const { return rec_type_; } + + bool is_rec_request() const { return rec_type_ != RecType::kNone; } + + bool is_onerec_model() const { return rec_type_ == RecType::kOneRec; } + + static const std::string ENCODER_SPARSE_EMBEDDING_NAME; + static const std::string DECODER_CONTEXT_EMBEDDING_NAME; + void set_cancel() { cancelled_.store(true, std::memory_order_relaxed); } bool cancelled() const { return cancelled_.load(std::memory_order_relaxed); } private: + void init_onerec_sequence(const std::vector& prompt_token_ids, + torch::Tensor input_embedding); + + SequenceOutput build_onerec_output(const Slice& ids, + size_t size, + SequenceOutput output) const; + + SequenceOutput build_onerec_streaming_output(const Slice& ids, + size_t size) const; + + SequenceOutput generate_onerec_output(const Slice& ids, + size_t size) const; + + struct OneRecState { + size_t num_encoder_tokens = 0; + size_t num_decoder_embeddings = 0; + std::vector encoder_tokens; + }; + // the index of the sequence in the request size_t index_ = 0; @@ -319,6 +372,10 @@ class Sequence final { // the length of the prompt tokens size_t num_prompt_tokens_ = 0; + std::optional onerec_state_; + + RecType rec_type_ = RecType::kNone; + // NOTE: MUST FIXME Later // record all tokens num in last turn when the request is // interrupted due to the lack of kv cache capacity. 
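Reviewer note on the Sequence changes above: the OneRec path sizes its token buffer from kDecoderBosTokenCount (1) and kDecoderMaxTokenCount (4), growing by the number of decoder context embedding rows when one is supplied. The standalone sketch below only restates that sizing arithmetic for clarity; the helper name onerec_token_capacity is invented for illustration and is not part of this patch.

#include <cassert>
#include <cstddef>

namespace {
constexpr size_t kDecoderBosTokenCount = 1;
constexpr size_t kDecoderMaxTokenCount = 4;

// num_decoder_embeddings == 0 models the case where no
// decoder_context_embedding is present in mm_data.
size_t onerec_token_capacity(size_t num_decoder_embeddings) {
  if (num_decoder_embeddings == 0) {
    // Only the BOS token is written as prompt; capacity stays at the max.
    return kDecoderMaxTokenCount;
  }
  // With a context embedding, no BOS token is written and the capacity grows
  // by the number of embedding rows, minus the unused BOS slot.
  return num_decoder_embeddings + kDecoderMaxTokenCount - kDecoderBosTokenCount;
}
}  // namespace

int main() {
  assert(onerec_token_capacity(0) == 4);   // BOS plus up to 3 generated tokens
  assert(onerec_token_capacity(8) == 11);  // 8 + 4 - 1
  return 0;
}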
diff --git a/xllm/core/runtime/CMakeLists.txt b/xllm/core/runtime/CMakeLists.txt index a0054bb23..a99d4bd50 100644 --- a/xllm/core/runtime/CMakeLists.txt +++ b/xllm/core/runtime/CMakeLists.txt @@ -22,6 +22,9 @@ cc_library( dit_worker.h embed_worker_impl.h embed_vlm_worker_impl.h + rec_worker_impl.h + llmrec_worker_impl.h + onerec_worker_impl.h worker_client.h xservice_client.h speculative_worker_impl.h @@ -38,6 +41,9 @@ cc_library( dit_worker.cpp embed_worker_impl.cpp embed_vlm_worker_impl.cpp + rec_worker_impl.cpp + llmrec_worker_impl.cpp + onerec_worker_impl.cpp worker_client.cpp xservice_client.cpp params_utils.cpp diff --git a/xllm/core/runtime/llmrec_worker_impl.cpp b/xllm/core/runtime/llmrec_worker_impl.cpp new file mode 100644 index 000000000..83ec334ba --- /dev/null +++ b/xllm/core/runtime/llmrec_worker_impl.cpp @@ -0,0 +1,131 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llmrec_worker_impl.h" + +#include +#include + +#include +#include + +#include "common/types.h" +#include "core/layers/word_embedding.h" + +namespace xllm { + +LlmRecWorkerImpl::LlmRecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options) + : RecWorkerImpl(parallel_args, device, options) {} + +void LlmRecWorkerImpl::prepare_work_before_execute( + const ForwardInput& inputs, + ForwardInput& processed_inputs) { + WorkerImpl::prepare_work_before_execute(inputs, processed_inputs); + + if (!inputs.input_params.mm_data.valid()) { + return; + } + + torch::Tensor input_embedding; + torch::Tensor input_tokens_tensor; + torch::Tensor input_indices_tensor; + + const auto& mm_data = inputs.input_params.mm_data; + const auto& processed_mm_data = processed_inputs.input_params.mm_data; + + if (auto res = + processed_mm_data.get(LLM_REC_INPUT_TOKENS)) { + input_tokens_tensor = res.value(); + } + + // input indices: position indices need to be generated on the Host side + if (auto res = mm_data.get(LLM_REC_INPUT_INDICES)) { + input_indices_tensor = res.value(); + } + + if (auto res = + processed_mm_data.get(LLM_REC_INPUT_EMBEDDING)) { + input_embedding = res.value(); + } + + if (input_embedding.defined()) { + input_embedding = input_embedding.to(dtype()); + } + + if (input_indices_tensor.defined()) { + layer::WordEmbedding word_embedding = get_word_embedding(); + torch::Tensor input_tokens_embedding = + word_embedding(input_tokens_tensor, 0); + + if (input_embedding.defined()) { + std::vector input_indices( + input_indices_tensor.data_ptr(), + input_indices_tensor.data_ptr() + input_indices_tensor.numel()); + + processed_inputs.input_params.input_embedding = + merge_embeddings_by_indices( + input_tokens_embedding, input_embedding, input_indices); + } else { + processed_inputs.input_params.input_embedding = input_tokens_embedding; + } + } else if (input_embedding.defined()) { + processed_inputs.input_params.input_embedding = input_embedding; + } +} + +torch::Tensor 
LlmRecWorkerImpl::merge_embeddings_by_indices( + const torch::Tensor& input_tokens_embedding, + const torch::Tensor& input_embedding, + const std::vector& input_indices) { + CHECK_EQ(input_embedding.dim(), 2); + CHECK_EQ(input_tokens_embedding.dim(), 2); + CHECK_EQ(input_tokens_embedding.size(1), input_embedding.size(1)); + CHECK_EQ(input_tokens_embedding.dtype(), input_embedding.dtype()); + CHECK_EQ(input_tokens_embedding.device(), input_embedding.device()); + + const int64_t total_rows = + input_tokens_embedding.size(0) + input_embedding.size(0); + const int64_t cols = input_embedding.size(1); + + torch::Device device = input_embedding.device(); + torch::Tensor merged = torch::empty( + {total_rows, cols}, torch::dtype(input_embedding.dtype()).device(device)); + + std::vector input_embedding_indices; + for (int i = 0; i < static_cast(total_rows); ++i) { + if (std::find(input_indices.begin(), input_indices.end(), i) == + input_indices.end()) { + input_embedding_indices.push_back(i); + } + } + + CHECK_EQ(input_embedding_indices.size(), input_embedding.size(0)); + + torch::Tensor input_embedding_indices_tensor = + torch::tensor(input_embedding_indices, torch::kInt64).to(device); + merged.index_put_({input_embedding_indices_tensor, torch::indexing::Ellipsis}, + input_embedding); + + torch::Tensor input_indices_tensor = + torch::tensor(input_indices, torch::kInt64).to(device); + merged.index_put_({input_indices_tensor, torch::indexing::Ellipsis}, + input_tokens_embedding); + + return merged; +} + +} // namespace xllm diff --git a/xllm/core/runtime/llmrec_worker_impl.h b/xllm/core/runtime/llmrec_worker_impl.h new file mode 100644 index 000000000..02f263e6d --- /dev/null +++ b/xllm/core/runtime/llmrec_worker_impl.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "runtime/rec_worker_impl.h" + +namespace xllm { + +class LlmRecWorkerImpl final : public RecWorkerImpl { + public: + LlmRecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options); + + ~LlmRecWorkerImpl() override = default; + + void prepare_work_before_execute(const ForwardInput& inputs, + ForwardInput& processed_inputs) override; + + private: + torch::Tensor merge_embeddings_by_indices( + const torch::Tensor& input_tokens_embedding, + const torch::Tensor& input_embedding, + const std::vector& input_indices); +}; + +} // namespace xllm diff --git a/xllm/core/runtime/onerec_worker_impl.cpp b/xllm/core/runtime/onerec_worker_impl.cpp new file mode 100644 index 000000000..c0a35e84d --- /dev/null +++ b/xllm/core/runtime/onerec_worker_impl.cpp @@ -0,0 +1,132 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "onerec_worker_impl.h" + +#include + +#include + +#include "common/device_monitor.h" +#include "common/metrics.h" +#include "framework/model/model_input_params.h" +#include "util/timer.h" + +namespace xllm { + +OneRecWorkerImpl::OneRecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options) + : RecWorkerImpl(parallel_args, device, options) {} + +std::optional OneRecWorkerImpl::step(const ForwardInput& input) { + Timer timer; + device_.set_device(); + + const auto& sampling_params = input.sampling_params; + const auto& input_params = input.input_params; + + if (!input_params.rec_params.has_value()) { + LOG(ERROR) << "OneRecWorkerImpl requires rec_params."; + return std::nullopt; + } + + const auto& rec_params = input_params.rec_params.value(); + + torch::Tensor hidden_states; + if (rec_params.rec_stage == RecModelInputParams::RecStage::PREFILL) { + if (!rec_params.is_first_prefill) { + ModelInputParams decoder_params = input_params; + decoder_params.rec_params->is_encoder_forward = false; + hidden_states = model_executor_->forward( + input.token_ids, input.positions, kv_caches_, decoder_params); + } else { + const bool has_sparse_embedding = + rec_params.encoder_sparse_embedding.defined(); + const bool has_encoder_tokens = rec_params.encoder_token_ids.defined() && + rec_params.encoder_positions.defined(); + + if (!has_sparse_embedding && !has_encoder_tokens) { + LOG(ERROR) << "OneRecWorkerImpl first prefill requires encoder inputs."; + return std::nullopt; + } + + ModelInputParams encoder_params = input_params; + encoder_params.rec_params->is_encoder_forward = true; + + torch::Tensor encoder_tokens; + if (has_sparse_embedding) { + encoder_params.rec_params->is_hybrid_mode = true; + encoder_tokens = rec_params.encoder_sparse_embedding; + } else { + encoder_tokens = rec_params.encoder_token_ids; + } + + model_executor_->forward(encoder_tokens, + rec_params.encoder_positions, + kv_caches_, + encoder_params); + + ModelInputParams decoder_params = input_params; + decoder_params.rec_params->is_encoder_forward = false; + hidden_states = model_executor_->forward( + input.token_ids, input.positions, kv_caches_, decoder_params); + } + } else { + ModelInputParams decoder_params = input_params; + decoder_params.rec_params->is_encoder_forward = false; + hidden_states = model_executor_->forward( + input.token_ids, input.positions, kv_caches_, decoder_params); + } + + if (!hidden_states.defined()) { + return std::nullopt; + } + + if (!enable_schedule_overlap() && !driver_ && !dp_driver_ && + !options_.enable_speculative_decode()) { + device_.synchronize_default_stream(); + COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); + DeviceMonitor::get_instance().update_active_activation_memory( + device_.index()); + return std::nullopt; + } + + torch::Tensor logits; + if (sampling_params.selected_token_idxes.defined()) { + logits = + model_->logits(hidden_states, sampling_params.selected_token_idxes); + } + + ForwardOutput output; + + if 
(sampling_params.selected_token_idxes.defined()) { + auto sample_output = sampler_->forward(logits, sampling_params); + output.logits = logits; + output.sample_output = sample_output; + output.do_sample = sampling_params.do_sample; + output.logprobs = sampling_params.logprobs; + output.max_top_logprobs = sampling_params.max_top_logprobs; + } + + device_.synchronize_default_stream(); + COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); + DeviceMonitor::get_instance().update_active_activation_memory( + device_.index()); + + return output; +} + +} // namespace xllm diff --git a/xllm/core/runtime/onerec_worker_impl.h b/xllm/core/runtime/onerec_worker_impl.h new file mode 100644 index 000000000..18643ad28 --- /dev/null +++ b/xllm/core/runtime/onerec_worker_impl.h @@ -0,0 +1,37 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "runtime/rec_worker_impl.h" + +namespace xllm { + +class OneRecWorkerImpl final : public RecWorkerImpl { + public: + OneRecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options); + + ~OneRecWorkerImpl() override = default; + + std::optional step(const ForwardInput& input) override; +}; + +} // namespace xllm diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index 41ddf4b1e..1abce7ce2 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -38,6 +38,9 @@ struct Options { // model backend PROPERTY(std::string, backend); + // rec model type hint (e.g. onerec/qwen2/qwen3), used to select the WorkerType::REC implementation + PROPERTY(std::optional, rec_model_type); + // devices for execute model PROPERTY(std::vector, devices); diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp new file mode 100644 index 000000000..71caef904 --- /dev/null +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -0,0 +1,53 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "rec_worker_impl.h" + +#include +#include + +#include "util/env_var.h" + +namespace xllm { + +RecWorkerImpl::RecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options) + : LLMWorkerImpl(parallel_args, device, options) { + if (!is_driver()) { + return; + } + + const int64_t num_threads = std::max( + 1, util::get_int_env("XLLM_REC_INPUT_BUILDER_THREADS", 16)); + input_builder_thread_pool_ = + std::make_shared(static_cast(num_threads)); +} + +bool RecWorkerImpl::init_model(ModelContext& context) { + return LLMWorkerImpl::init_model(context); +} + +ForwardInput RecWorkerImpl::prepare_inputs(Batch& batch) { + ThreadPool* thread_pool = + input_builder_thread_pool_ ? input_builder_thread_pool_.get() : nullptr; + + return batch.prepare_rec_forward_input(options_.num_decoding_tokens(), + /*min_decoding_batch_size=*/0, + context_.get_model_args(), + thread_pool); +} + +} // namespace xllm diff --git a/xllm/core/runtime/rec_worker_impl.h b/xllm/core/runtime/rec_worker_impl.h new file mode 100644 index 000000000..6d4339c82 --- /dev/null +++ b/xllm/core/runtime/rec_worker_impl.h @@ -0,0 +1,42 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "runtime/llm_worker_impl.h" + +namespace xllm { + +class RecWorkerImpl : public LLMWorkerImpl { + public: + RecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options); + + ~RecWorkerImpl() override = default; + + bool init_model(ModelContext& context) override; + + ForwardInput prepare_inputs(Batch& batch) override; + + protected: + std::shared_ptr input_builder_thread_pool_; +}; + +} // namespace xllm diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index 214c38f52..365f849aa 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -26,12 +26,15 @@ limitations under the License. #include #include "common/metrics.h" +#include "common/rec_model_utils.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" #include "framework/state_dict/state_dict.h" #include "runtime/embed_vlm_worker_impl.h" #include "runtime/embed_worker_impl.h" #include "runtime/llm_worker_impl.h" +#include "runtime/llmrec_worker_impl.h" +#include "runtime/onerec_worker_impl.h" #include "runtime/speculative_worker_impl.h" #include "runtime/vlm_worker_impl.h" #include "util/timer.h" @@ -52,10 +55,19 @@ Worker::Worker(const ParallelArgs& parallel_args, } else if (worker_type == WorkerType::EVLM) { impl_ = new EmbedVLMWorkerImpl(parallel_args, device, options); } else if (worker_type == WorkerType::REC) { - // TODO. add following when next pr (use RecWorkerImpl). - // impl_ = new RecWorkerImpl(parallel_args, device, options); - // TODO. 
delete this when next pr. - impl_ = new LLMWorkerImpl(parallel_args, device, options); + const auto& rec_model_type = options.rec_model_type(); + CHECK(rec_model_type.has_value()) + << "rec_model_type is required for REC worker"; + + const RecModelKind kind = get_rec_model_kind(rec_model_type.value()); + CHECK_NE(kind, RecModelKind::kNone) + << "Unsupported rec_model_type: " << rec_model_type.value(); + + if (kind == RecModelKind::kOneRec) { + impl_ = new OneRecWorkerImpl(parallel_args, device, options); + } else { + impl_ = new LlmRecWorkerImpl(parallel_args, device, options); + } } else { LOG(ERROR) << "Unknown worker type, please check logic"; } diff --git a/xllm/core/scheduler/fixed_steps_scheduler.cpp b/xllm/core/scheduler/fixed_steps_scheduler.cpp index 0d0d51411..e95f15091 100644 --- a/xllm/core/scheduler/fixed_steps_scheduler.cpp +++ b/xllm/core/scheduler/fixed_steps_scheduler.cpp @@ -33,7 +33,6 @@ limitations under the License. #include "framework/request/sequence.h" namespace xllm { - FixedStepsScheduler::FixedStepsScheduler(Engine* engine, const Options& options) : ContinuousScheduler(engine, options) {}
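Reviewer note on llmrec_worker_impl.cpp above: merge_embeddings_by_indices interleaves rows from two matrices, placing the word-embedding rows at the positions listed in input_indices and the precomputed embedding rows, in ascending order, everywhere else. The toy sketch below is not the worker code; the shapes, values, and row indices are made up purely to show the two index_put_ scatters on which it relies.

#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  // 2 token-embedding rows and 3 precomputed rows -> 5 merged rows, width 4.
  torch::Tensor token_emb = torch::ones({2, 4});
  torch::Tensor precomp_emb = torch::zeros({3, 4});
  std::vector<int64_t> token_rows_vec = {1, 3};  // rows taken by token_emb

  torch::Tensor merged = torch::empty({5, 4});
  torch::Tensor token_rows = torch::tensor(token_rows_vec, torch::kInt64);
  // Every row not listed above, in ascending order, receives precomp_emb.
  torch::Tensor precomp_rows = torch::tensor({0, 2, 4}, torch::kInt64);

  merged.index_put_({token_rows, torch::indexing::Ellipsis}, token_emb);
  merged.index_put_({precomp_rows, torch::indexing::Ellipsis}, precomp_emb);

  // Resulting row pattern: precomp, token, precomp, token, precomp.
  std::cout << merged << std::endl;
  return 0;
}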