diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index 3a21d768b..f786437d2 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -27,8 +27,7 @@ limitations under the License. #include "core/common/metrics.h" #include "core/distributed_runtime/dit_master.h" #include "core/distributed_runtime/llm_master.h" -// TODO. add following when next pr. -// #include "core/runtime/rec_master.h" +#include "core/distributed_runtime/rec_master.h" #include "core/distributed_runtime/vlm_master.h" #include "core/util/closure_guard.h" #include "embedding.pb.h" @@ -73,8 +72,6 @@ APIService::APIService(Master* master, std::make_unique( dynamic_cast(master), model_names); } else if (FLAGS_backend == "rec") { - // TODO. delete this when next pr. - using RecMaster = LLMMaster; rec_completion_service_impl_ = std::make_unique( dynamic_cast(master), model_names); } diff --git a/xllm/api_service/rec_completion_service_impl.cpp b/xllm/api_service/rec_completion_service_impl.cpp index 8fade9e6b..9ffc4261b 100644 --- a/xllm/api_service/rec_completion_service_impl.cpp +++ b/xllm/api_service/rec_completion_service_impl.cpp @@ -27,10 +27,9 @@ limitations under the License. #include "common/instance_name.h" #include "completion.pb.h" #include "core/distributed_runtime/llm_master.h" +#include "core/distributed_runtime/rec_master.h" #include "core/framework/request/mm_data.h" #include "core/framework/request/request_output.h" -// TODO. add following when next pr. -// #include "core/runtime/rec_master.h" #include "core/util/utils.h" #define likely(x) __builtin_expect(!!(x), 1) @@ -89,9 +88,7 @@ bool send_result_to_client_brpc_rec(std::shared_ptr call, // Add rec specific output tensors auto output_tensor = response.mutable_output_tensors()->Add(); output_tensor->set_name("rec_result"); - // TODO: add following when next pr. - // if (FLAGS_enable_constrained_decoding) { - if (true) { + if (FLAGS_enable_constrained_decoding) { output_tensor->set_datatype(proto::DataType::INT64); output_tensor->mutable_shape()->Add(req_output.outputs.size()); output_tensor->mutable_shape()->Add(1); // Single item per output @@ -190,11 +187,8 @@ void RecCompletionServiceImpl::process_async_impl( master_->handle_request( std::move(rpc_request_ref.prompt()), std::move(prompt_tokens), - // TODO. add following when next pr. - // std::move(mm_data), + std::move(mm_data), std::move(request_params), - // TODO. delete this when next pr. - call.get(), [call, model, master = master_, diff --git a/xllm/api_service/rec_completion_service_impl.h b/xllm/api_service/rec_completion_service_impl.h index e383a103f..df627bcac 100644 --- a/xllm/api_service/rec_completion_service_impl.h +++ b/xllm/api_service/rec_completion_service_impl.h @@ -19,6 +19,7 @@ limitations under the License. #include "api_service_impl.h" #include "completion.pb.h" +#include "core/distributed_runtime/rec_master.h" #include "rec.pb.h" #include "stream_call.h" @@ -27,10 +28,6 @@ namespace xllm { using CompletionCall = StreamCall; -// TODO. add following when next pr. 
-// class RecMaster; -using RecMaster = LLMMaster; - // a class to handle completion requests class RecCompletionServiceImpl final : public APIServiceImpl { public: @@ -45,4 +42,4 @@ class RecCompletionServiceImpl final : public APIServiceImpl { RecMaster* master_ = nullptr; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 0959d1a6c..ccb4a08e5 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -164,6 +164,11 @@ DEFINE_int32( 256, "Max decode token per sequence which used for ZeroEvictionScheduler."); +// for rec, it's better to set to 100; +DEFINE_int32(request_queue_size, + 100000, + "The request queue size of the scheduler"); + // --- parallel config --- DEFINE_int32(dp_size, 1, "Data parallel size for MLA attention."); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 7bcd8043c..3135a4410 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -187,6 +187,8 @@ DECLARE_bool(enable_latency_aware_schedule); DECLARE_int32(profile_max_prompt_length); +DECLARE_int32(request_queue_size); + DECLARE_bool(enable_profile_kv_blocks); DECLARE_bool(disable_ttft_profiling); diff --git a/xllm/core/distributed_runtime/CMakeLists.txt b/xllm/core/distributed_runtime/CMakeLists.txt index 7a81dd749..34cd0300e 100644 --- a/xllm/core/distributed_runtime/CMakeLists.txt +++ b/xllm/core/distributed_runtime/CMakeLists.txt @@ -21,6 +21,8 @@ cc_library( vlm_engine.h vlm_master.h speculative_engine.h + rec_engine.h + rec_master.h disagg_pd_service.h disagg_pd_service_impl.h pd_ooc_service.h @@ -40,6 +42,8 @@ cc_library( vlm_engine.cpp vlm_master.cpp speculative_engine.cpp + rec_engine.cpp + rec_master.cpp disagg_pd_service.cpp disagg_pd_service_impl.cpp pd_ooc_service.cpp diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index 5d5c9b248..9fe82e593 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -34,6 +34,8 @@ limitations under the License. 
#include "llm_engine.h" #include "llm_master.h" #include "models/model_registry.h" +#include "rec_engine.h" +#include "rec_master.h" #include "speculative_engine.h" #include "util/device_name_utils.h" #include "util/scope_guard.h" @@ -231,6 +233,35 @@ Master::Master(const Options& options, EngineType type) : options_(options) { eng_options.device_ip(options_.device_ip().value()); } engine_ = std::make_unique(eng_options); + } else if (type == EngineType::REC) { + options_.enable_schedule_overlap(false); + LOG(WARNING) << "Force to disable schedule overlap for REC model, not " + "supported yet."; + runtime::Options eng_options; + eng_options.model_path(options_.model_path()) + .devices(devices) + .backend(options_.backend()) + .block_size(options_.block_size()) + .max_cache_size(options_.max_cache_size()) + .max_memory_utilization(options_.max_memory_utilization()) + .enable_prefix_cache(options_.enable_prefix_cache()) + .task_type(options_.task_type()) + .enable_chunked_prefill(options_.enable_chunked_prefill()) + .enable_offline_inference(options_.enable_offline_inference()) + .spawn_worker_path(options_.spawn_worker_path()) + .enable_shm(options_.enable_shm()) + .is_local(options_.is_local()) + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .master_node_addr(options_.master_node_addr()) + .nnodes(options_.nnodes()) + .node_rank(options_.node_rank()) + .dp_size(options_.dp_size()) + .ep_size(options_.ep_size()) + .max_seqs_per_batch(options_.max_seqs_per_batch()) + .max_tokens_per_chunk_for_prefill( + options_.max_tokens_per_chunk_for_prefill()); + + engine_ = std::make_unique(eng_options); } else { LOG(WARNING) << "Not supported llm engine type: " << static_cast(type); @@ -246,6 +277,9 @@ std::unique_ptr create_master(const std::string& backend, } else if (backend == "dit") { LOG(INFO) << "creating dit master"; return std::make_unique(options); + } else if (backend == "rec") { + LOG(INFO) << "creating rec master"; + return std::make_unique(options); } else { LOG(FATAL) << "Failed to create master, backend is" << backend; return nullptr; diff --git a/xllm/core/distributed_runtime/rec_engine.cpp b/xllm/core/distributed_runtime/rec_engine.cpp new file mode 100644 index 000000000..5977f6585 --- /dev/null +++ b/xllm/core/distributed_runtime/rec_engine.cpp @@ -0,0 +1,342 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "rec_engine.h" + +#include + +#include +#include + +#include "common/metrics.h" +#include "framework/model/model_args.h" +#include "framework/model_loader.h" +#include "framework/parallel_state/parallel_state.h" +#include "util/pretty_print.h" +#include "util/timer.h" +#include "util/utils.h" + +namespace xllm { + +RecEngine::RecEngine(const runtime::Options& options) : options_(options) { + const auto& devices = options_.devices(); + CHECK_GT(devices.size(), 0) << "At least one device is required"; + + CHECK(!devices[0].is_cpu()) << "CPU device is not supported"; + const auto device_type = devices[0].type(); + for (const auto device : devices) { + CHECK_EQ(device.type(), device_type) + << "All devices should be the same type"; + } + + // initialize process groups if there are multiple devices + if (devices.size() > 1) { + // create a process group for each device if there are multiple gpus + process_groups_ = parallel_state::create_npu_process_groups(devices); + } + + WorkerType worker_type = WorkerType::REC; + const int32_t world_size = static_cast(devices.size()); + for (size_t i = 0; i < devices.size(); ++i) { + const int32_t rank = static_cast(i); + ProcessGroup* pg = world_size > 1 ? process_groups_[i].get() : nullptr; + ParallelArgs parallel_args(rank, world_size, pg); + workers_.emplace_back(std::make_unique( + parallel_args, devices[i], options_, worker_type)); + } + + if (workers_.size() > 1) { + // test process group + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.emplace_back(worker->process_group_test_async()); + } + // wait up to 4 seconds for all futures to complete + folly::collectAll(futures).within(std::chrono::seconds(4)).get(); + } +} + +bool RecEngine::init() { + if (!init_model()) { + LOG(ERROR) << "Failed to init model from: " << options_.model_path(); + return false; + } + + auto kv_cache_cap = estimate_kv_cache_capacity(); + + if (!allocate_kv_cache(kv_cache_cap)) { + LOG(ERROR) << "Failed to allocate kv cache"; + return false; + } + + return true; +} + +bool RecEngine::init_model() { + const std::string& model_path = options_.model_path(); + auto model_loader = ModelLoader::create(model_path); + LOG(INFO) << "Initializing model from: " << model_path; + + // RecEngine does not use tokenizer + tokenizer_ = model_loader->tokenizer(); + CHECK(tokenizer_ != nullptr); + + args_ = model_loader->model_args(); + quant_args_ = model_loader->quant_args(); + tokenizer_args_ = model_loader->tokenizer_args(); + + // compute the number of local kv heads and head dim + const int world_size = static_cast(workers_.size()); + const int64_t n_heads = args_.n_heads(); + const int64_t n_kv_heads = args_.n_kv_heads().value_or(n_heads); + n_local_kv_heads_ = std::max(1, n_kv_heads / world_size); + head_dim_ = args_.head_dim(); + dtype_ = xllm::util::parse_dtype(args_.dtype(), options_.devices()[0]); + + // key + value for all layers + LOG(INFO) << "Block info, block_size: " << options_.block_size() + << ", n_local_kv_heads: " << n_local_kv_heads_ + << ", head_dim: " << head_dim_ << ", n_layers: " << args_.n_layers() + << ", dtype: " << dtype_; + + // RecEngine does not use tokenizer, skip vocab_size check + + LOG(INFO) << "Initializing model with " << args_; + LOG(INFO) << "Initializing model with quant args: " << quant_args_; + LOG(INFO) << "Initializing model with tokenizer args: " << tokenizer_args_; + + // init model for each worker in 
parallel + // multiple workers, call async init + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->init_model_async(model_path, FLAGS_random_seed)); + } + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (const auto& result : results) { + if (!result.value()) { + return false; + } + } + + return true; +} + +Engine::KVCacheCapacity RecEngine::estimate_kv_cache_capacity() { + const int64_t max_cache_size = options_.max_cache_size(); + const double max_memory_utilization = options_.max_memory_utilization(); + + const auto& device = workers_[0]->device(); + // call worker to profile memory usage + std::vector>> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->estimate_kv_cache_capacity_async()); + } + + // pick smallest available memory from all devices + int64_t cache_size_in_bytes = std::numeric_limits::max(); + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (size_t i = 0; i < results.size(); ++i) { + const auto device = workers_[i]->device(); + if (!results[i].hasValue()) { + LOG(ERROR) << "Failed to profile memory usage for device: " << device; + continue; + } + auto [available_memory, total_memory] = results[i].value(); + LOG(INFO) << device + << ": available memory: " << readable_size(available_memory) + << ", total memory: " << readable_size(total_memory) + << ", Using max_memory_utilization: " << max_memory_utilization + << ", max_cache_size: " << readable_size(max_cache_size); + // apply memory cap from config if it is set + if (max_memory_utilization < 1.0) { + const int64_t buffer_memory = + total_memory * (1.0 - max_memory_utilization); + available_memory -= buffer_memory; + } + if (max_cache_size > 0) { + available_memory = std::min(available_memory, max_cache_size); + } + cache_size_in_bytes = std::min(cache_size_in_bytes, available_memory); + } + + KVCacheCapacity kv_cache_cap; + kv_cache_cap.cache_size_in_bytes = std::max(cache_size_in_bytes, int64_t(0)); + CHECK_GT(kv_cache_cap.cache_size_in_bytes, 0) + << "Available kv cache size must be greater than 0"; + + // compute kv cache slot size + const auto dtype_size = torch::scalarTypeToTypeMeta(dtype_).itemsize(); + // key + value for all layers + const int64_t slot_size = + 2 * n_local_kv_heads_ * head_dim_ * args_.n_layers() * dtype_size; + kv_cache_cap.slot_size = slot_size; + + // compute kv blocks num + const int32_t block_size = options_.block_size(); + const int64_t block_size_in_bytes = block_size * slot_size; + kv_cache_cap.n_blocks = cache_size_in_bytes / block_size_in_bytes; + CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; + + return kv_cache_cap; +} + +bool RecEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { + LOG(INFO) << "kv cache capacity: " + << "bytes: " << kv_cache_cap.cache_size_in_bytes + << ", blocks: " << kv_cache_cap.n_blocks + << ", slot_size: " << kv_cache_cap.slot_size; + + const int32_t block_size = options_.block_size(); + + // init kv cache for each worker + std::vector> kv_cache_shape; + kv_cache_shape.reserve(2); + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_}); + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_}); + + LOG(INFO) << "Initializing k cache with shape: [" << kv_cache_shape[0] << "]"; + LOG(INFO) << "Initializing 
v cache with shape: [" << kv_cache_shape[1] << "]"; + + // initialize block manager + BlockManagerPool::Options options; + options.num_blocks(kv_cache_cap.n_blocks) + .host_num_blocks(kv_cache_cap.n_blocks) + .block_size(block_size) + .enable_prefix_cache(options_.enable_prefix_cache()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_cache_upload(options_.enable_cache_upload()); + kv_cache_manager_ = std::make_unique(options); + + // init kv cache for each worker in parallel + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->allocate_kv_cache_async(kv_cache_shape)); + } + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (const auto& result : results) { + if (!result.value()) { + return false; + } + } + return true; +} + +// RecEngine executes model: prefill + decode steps +// Similar to LLMEngine but simplified for rec model +ForwardOutput RecEngine::step(std::vector& batches) { + if (workers_.empty()) { + // empty worker, return + return {}; + } + + Timer timer; + auto forward_inputs = workers_[0]->prepare_inputs(batches[0]); + COUNTER_ADD(prepare_input_latency_microseconds, timer.elapsed_microseconds()); + + if (!forward_inputs.token_ids.defined()) { + // empty input, just return + return {}; + } + + timer.reset(); + // Prefill step: Run the first model execution + const auto& prefill_output = get_model_output(forward_inputs); + COUNTER_ADD(rec_first_token_latency_microseconds, + timer.elapsed_microseconds()); + + timer.reset(); + // Use process_sample_output from Batch class (same as LLMEngine) + batches[0].process_sample_output(prefill_output.sample_output, false); + COUNTER_ADD(rec_sampling_latency_microseconds, timer.elapsed_microseconds()); + + // Decode steps: Run the model 2 more times for decoding + ForwardOutput decode_output; + + for (int i = 0; i < 2; ++i) { + timer.reset(); + forward_inputs = workers_[0]->prepare_inputs(batches[0]); + COUNTER_ADD(prepare_input_latency_microseconds, + timer.elapsed_microseconds()); + + timer.reset(); + decode_output = get_model_output(forward_inputs); + if (i == 0) { + COUNTER_ADD(rec_second_token_latency_microseconds, + timer.elapsed_microseconds()); + } else if (i == 1) { + COUNTER_ADD(rec_third_token_latency_microseconds, + timer.elapsed_microseconds()); + } + + timer.reset(); + // Use process_sample_output from Batch class (same as LLMEngine) + batches[0].process_sample_output(decode_output.sample_output, false); + COUNTER_ADD(rec_sampling_latency_microseconds, + timer.elapsed_microseconds()); + } + + batches[0].finish(); + + // Return the final model output + return decode_output; +} + +void RecEngine::update_last_step_result(std::vector& batch) { + UNUSED_PARAMETER(batch); +} + +std::vector RecEngine::get_active_activation_memory() const { + // call worker to get active activation memory + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->get_active_activation_memory_async()); + } + + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + std::vector active_activation_memories; + active_activation_memories.reserve(workers_.size()); + for (auto& result : results) { + active_activation_memories.push_back(result.value()); + } + return active_activation_memories; +} + +ForwardOutput RecEngine::get_model_output(const ForwardInput& model_inputs) { + std::vector>> futures; + futures.reserve(workers_.size()); + for 
(auto& worker : workers_) { + futures.emplace_back(worker->step_async(model_inputs)); + } + // wait for the all future to complete + auto results = folly::collectAll(futures).get(); + // return the result from the driver + auto forward_output = results.front().value(); + + CHECK(forward_output.has_value()) << "Failed to execute model"; + return forward_output.value(); +} + +} // namespace xllm diff --git a/xllm/core/distributed_runtime/rec_engine.h b/xllm/core/distributed_runtime/rec_engine.h new file mode 100644 index 000000000..0582b349b --- /dev/null +++ b/xllm/core/distributed_runtime/rec_engine.h @@ -0,0 +1,80 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "common/macros.h" +#include "engine.h" +#include "framework/batch/batch.h" +#include "framework/block/block_manager_pool.h" +#include "framework/quant_args.h" +#include "framework/tokenizer/tokenizer.h" +#include "framework/tokenizer/tokenizer_args.h" +#include "runtime/worker.h" + +namespace xllm { + +class RecEngine : public Engine { + public: + // create an engine with the given devices + RecEngine(const runtime::Options& options); + + virtual ~RecEngine() = default; + + ForwardOutput step(std::vector& batch) override; + + const runtime::Options& options() const { return options_; } + + bool init() override; + + void update_last_step_result(std::vector& batch) override; + + // return the active activation memory + std::vector get_active_activation_memory() const override; + + private: + bool init_model(); + Engine::KVCacheCapacity estimate_kv_cache_capacity(); + bool allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap); + + // Helper methods for rec-specific execution + ForwardOutput get_model_output(const ForwardInput& model_inputs); + + private: + // options + runtime::Options options_; + + // dtype + torch::ScalarType dtype_; + + // quantization args + QuantArgs quant_args_; + + // a list of process groups, with each process group handling a single device + std::vector> process_groups_; + + // a list of workers, with each worker handling a partial of model + std::vector> workers_; + + // config for kv cache + int64_t n_local_kv_heads_ = 0; + int64_t head_dim_ = 0; +}; + +} // namespace xllm diff --git a/xllm/core/distributed_runtime/rec_master.cpp b/xllm/core/distributed_runtime/rec_master.cpp new file mode 100644 index 000000000..0bb10a901 --- /dev/null +++ b/xllm/core/distributed_runtime/rec_master.cpp @@ -0,0 +1,271 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_master.h" + +#include +#include +#include +#include + +#include "common/macros.h" +#include "common/metrics.h" +#include "models/model_registry.h" +#include "rec_engine.h" +#include "runtime/xservice_client.h" +#include "scheduler/scheduler_factory.h" +#include "util/scope_guard.h" +#include "util/threadpool.h" +#include "util/utils.h" + +namespace xllm { + +RecMaster::RecMaster(const Options& options) + : Master(options, EngineType::REC) { + // Initialize with Rec engine type + // The rest of the initialization follows the same pattern as LLMMaster + CHECK(engine_->init()); + + model_args_ = engine_->model_args(); + + bool enable_decode_response_to_service = false; + if (options_.enable_service_routing()) { + XServiceClient* xservice_client = XServiceClient::get_instance(); + if (!xservice_client->init(options_.etcd_addr().value_or(""), + options_.xservice_addr().value_or(""), + options_.instance_name().value_or(""), + engine_->block_manager_pool())) { + LOG(FATAL) << "XServiceClient init fail!"; + return; + } + auto service_config = xservice_client->get_config(); + enable_decode_response_to_service = + service_config.enable_decode_response_to_service; + } + + ContinuousScheduler::Options scheduler_options; + scheduler_options.max_tokens_per_batch(options_.max_tokens_per_batch()) + .max_seqs_per_batch(options_.max_seqs_per_batch()) + .max_tokens_per_chunk_for_prefill( + options_.max_tokens_per_chunk_for_prefill()) + .num_speculative_tokens(options_.num_speculative_tokens()) + .dp_size(options_.dp_size()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .enable_chunked_prefill(options_.enable_chunked_prefill()) + .instance_role(options_.instance_role()) + .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) + .enable_service_routing(options_.enable_service_routing()) + .enable_decode_response_to_service(enable_decode_response_to_service); + scheduler_ = create_fixed_steps_scheduler(engine_.get(), scheduler_options); + + // OmniRec model does not have a tokenizer + chat_template_ = nullptr; + tokenizer_ = nullptr; + threadpool_ = + std::make_unique(options_.num_request_handling_threads()); +} + +void RecMaster::run() { + const bool already_running = running_.load(std::memory_order_relaxed); + if (already_running) { + LOG(WARNING) << "RecMaster is already running."; + return; + } + running_.store(true, std::memory_order_relaxed); + loop_thread_ = std::thread([this]() { + const auto timeout = absl::Milliseconds(5); + while (!stopped_.load(std::memory_order_relaxed)) { + // move scheduler forward + scheduler_->step(timeout); + } + running_.store(false, std::memory_order_relaxed); + }); + + // Engine run method is not available, remove this call +} + +RecMaster::~RecMaster() { + // set stop flag + stopped_.store(true, std::memory_order_relaxed); + // wait for the loop thread to finish + if (loop_thread_.joinable()) { + loop_thread_.join(); + } +} + +void RecMaster::handle_request(std::string prompt, + std::optional> 
prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback) { + // add one pending request + scheduler_->incr_pending_requests(1); + auto cb = [callback = std::move(callback), + scheduler = scheduler_.get()](const RequestOutput& output) { + output.log_request_status(); + return callback(output); + }; + // add into the queue + threadpool_->schedule([this, + prompt = std::move(prompt), + prompt_tokens = std::move(prompt_tokens), + mm_data = std::move(mm_data), + sp = std::move(sp), + callback = std::move(cb)]() mutable { + AUTO_COUNTER(request_handling_latency_seconds_completion); + + // remove the pending request after scheduling + SCOPE_GUARD([this] { scheduler_->decr_pending_requests(); }); + + Timer timer; + // verify the prompt + if (!sp.verify_params(callback)) { + return; + } + + auto request = generate_request(std::move(prompt), + std::move(prompt_tokens), + std::move(mm_data), + sp, + callback); + if (!request) { + return; + } + + if (!scheduler_->add_request(request)) { + CALLBACK_WITH_ERROR(StatusCode::RESOURCE_EXHAUSTED, + "No available resources to schedule request"); + } + }); +} + +std::shared_ptr RecMaster::generate_request( + std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback) { + // For Rec model, prompt is expected to be empty and prompt_tokens should + // contain the actual data Skip prompt empty check as mentioned in + // requirements + + Timer timer; + std::vector local_prompt_tokens; + + if (prompt_tokens.has_value()) { + local_prompt_tokens = std::move(prompt_tokens.value()); + LOG(INFO) + << "[Rec DEBUG] generate_request - received prompt_tokens.size(): " + << local_prompt_tokens.size() + << ", prompt.length(): " << prompt.length(); + } else if (!mm_data.has_value()) { + // sparse LLM + LOG(ERROR) << "Rec model requires prompt_tokens/embedding to be provided"; + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "Rec model requires prompt_tokens/embedding to be provided"); + return nullptr; + } + + COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + + int32_t max_context_len = model_args_.max_position_embeddings(); + if (!options_.enable_chunked_prefill()) { + max_context_len = + std::min(max_context_len, options_.max_tokens_per_batch()); + } + if (local_prompt_tokens.size() >= max_context_len) { + LOG(ERROR) << "Prompt is too long: " << local_prompt_tokens.size(); + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is too long"); + return nullptr; + } + + uint32_t max_tokens = sp.max_tokens; + if (max_tokens == 0) { + const uint32_t kDefaultMaxTokens = 5120; + max_tokens = kDefaultMaxTokens; + } + + // allocate enough capacity for prompt tokens, max tokens, and speculative + // tokens + size_t capacity = local_prompt_tokens.size() + max_tokens + + options_.num_speculative_tokens() + /*bonus_token*/ 1; + if (options_.enable_schedule_overlap()) { + capacity += options_.num_speculative_tokens() + 1; + } + const size_t best_of = sp.best_of.value_or(sp.n); + + RequestSamplingParam sampling_param; + sampling_param.frequency_penalty = sp.frequency_penalty; + sampling_param.presence_penalty = sp.presence_penalty; + sampling_param.repetition_penalty = sp.repetition_penalty; + sampling_param.temperature = sp.temperature; + sampling_param.top_p = sp.top_p; + sampling_param.top_k = sp.top_k; + sampling_param.logprobs = sp.logprobs; + sampling_param.top_logprobs = sp.top_logprobs; + sampling_param.is_embeddings = sp.is_embeddings; + 
sampling_param.beam_width = sp.beam_width; + if (best_of > sp.n) { + // enable logprobs for best_of to generate sequence logprob + sampling_param.logprobs = true; + } + // sampling_param.do_sample = sp.do_sample; + + bool stream = sp.streaming; + // results cannot be streamed when best_of != n + if (best_of != sp.n) { + stream = false; + } + // std::unordered_set stop_tokens; + // std::vector> stop_sequences; + // StoppingChecker stopping_checker( + // max_tokens, + // max_context_len - options_.num_speculative_tokens(), + // , + // model_args_.eos_token_id(), + // sp.ignore_eos, + // std::move(stop_tokens), + // std::move(stop_sequences)); + StoppingChecker stopping_checker; + RequestState req_state(std::move(prompt), + std::move(local_prompt_tokens), + mm_data.value_or(MMData{}), + std::move(sampling_param), + std::move(stopping_checker), + capacity, + sp.n, + best_of, + sp.logprobs, + stream, + sp.echo, + sp.skip_special_tokens, + options_.enable_schedule_overlap(), + callback, + nullptr, + sp.decode_address); + // TODO. add following when next pr (add is_rec_model and bos_token_id to + // RequestState). req_state.is_rec_model = true; req_state.bos_token_id = + // model_args_.bos_token_id(); + auto request = std::make_shared(sp.request_id, + sp.x_request_id, + sp.x_request_time, + std::move(req_state), + sp.service_request_id); + return request; +} + +} // namespace xllm diff --git a/xllm/core/distributed_runtime/rec_master.h b/xllm/core/distributed_runtime/rec_master.h new file mode 100644 index 000000000..0ed5b76d3 --- /dev/null +++ b/xllm/core/distributed_runtime/rec_master.h @@ -0,0 +1,71 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include + +#include "framework/chat_template/jinja_chat_template.h" +#include "framework/model/model_args.h" +#include "master.h" +#include "rec_engine.h" +#include "scheduler/continuous_scheduler.h" +#include "scheduler/fixed_steps_scheduler.h" +#include "util/threadpool.h" + +namespace xllm { + +class RecMaster : public Master { + public: + explicit RecMaster(const Options& options); + ~RecMaster(); + + // handle a request, the engine will execute the request asynchronously + // completion/encode + void handle_request(std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback); + + // start the handling loop + void run() override; + + private: + std::shared_ptr generate_request( + std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback); + + std::unique_ptr scheduler_; + // model args + ModelArgs model_args_; + std::unique_ptr threadpool_; + std::unique_ptr tokenizer_; + // chat template instance + std::unique_ptr chat_template_; + // thread for moving forward the scheduler + std::thread loop_thread_; + // flag to stop the loop + std::atomic stopped_{false}; + + // flag to indicate if the handler is running + std::atomic running_{false}; +}; + +} // namespace xllm diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index f33ac9bde..5699cc818 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -453,4 +453,11 @@ void Batch::process_beam_search_output(const RawForwardOutput& raw_output, update_for_sequence_group(sequence_group_id); } } + +void Batch::finish() { + // Finish all sequence groups + for (auto* sequence_group : sequence_groups_) { + sequence_group->finish(); + } +} } // namespace xllm diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index 798332736..31c409f82 100755 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -110,6 +110,10 @@ class Batch { // process the accepted output embedding void process_embedding_output(const torch::Tensor& embedding); + // mark all sequence groups as finished (used by rec model multi-round + // decoding) + void finish(); + const std::vector& get_allowed_max_tokens() const { return allowed_max_tokens_; } diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp index 81c0f4e7a..26578aa4b 100644 --- a/xllm/core/framework/request/sequence.cpp +++ b/xllm/core/framework/request/sequence.cpp @@ -489,4 +489,12 @@ bool Sequence::update_prefetch_result(uint32_t timeout) { return true; } +void Sequence::finish() { + finished_ = true; + finish_status_invalidated_ = false; + if (finish_reason_ == FinishReason::NONE) { + finish_reason_ = FinishReason::STOP; + } +} + } // namespace xllm diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h index d2f8c0d41..ee9b3f210 100644 --- a/xllm/core/framework/request/sequence.h +++ b/xllm/core/framework/request/sequence.h @@ -181,6 +181,8 @@ class Sequence final { FinishReason finish_reason() const { return finish_reason_; } // check finish status, use cached value if not invalidated bool finished() const; + // mark sequence as finished (used by rec model multi-round decoding) + void finish(); // get the output of the sequence until the specified number of tokens, // 
returns nullopt if no delta text and not finished diff --git a/xllm/core/framework/request/sequences_group.cpp b/xllm/core/framework/request/sequences_group.cpp index 7bbce9afd..c7760c0c6 100644 --- a/xllm/core/framework/request/sequences_group.cpp +++ b/xllm/core/framework/request/sequences_group.cpp @@ -292,4 +292,10 @@ void SequencesGroup::process_beam_search() { update_for_sequence(0, beam_width); } +void SequencesGroup::finish() { + for (auto& sequence : sequences_) { + sequence->finish(); + } +} + } // namespace xllm diff --git a/xllm/core/framework/request/sequences_group.h b/xllm/core/framework/request/sequences_group.h index ef3764bc0..59b0f3604 100644 --- a/xllm/core/framework/request/sequences_group.h +++ b/xllm/core/framework/request/sequences_group.h @@ -62,6 +62,9 @@ class SequencesGroup { return sequences_[0]->is_chunked_prefill_stage(); } + // mark all sequences as finished (used by rec model multi-round decoding) + void finish(); + private: void add(); diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index a58c6bbee..3c9c2bcf4 100755 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -38,6 +38,7 @@ class WorkerType { DIT, // DIT ELM, // Embedding LM EVLM, // Embedding VLM + REC, // Rec }; constexpr WorkerType(Value v) : value_(v) {} @@ -52,6 +53,8 @@ class WorkerType { value_ = ELM; } else if (str == "EVLM") { value_ = EVLM; + } else if (str == "REC") { + value_ = REC; } else { value_ = INVALID; } @@ -78,6 +81,8 @@ class WorkerType { return "ELM"; } else if (this->value_ == EVLM) { return "EVLM"; + } else if (this->value_ == REC) { + return "REC"; } else { return "INVALID"; } diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index d4d0124b1..214c38f52 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -51,6 +51,11 @@ Worker::Worker(const ParallelArgs& parallel_args, impl_ = new EmbedWorkerImpl(parallel_args, device, options); } else if (worker_type == WorkerType::EVLM) { impl_ = new EmbedVLMWorkerImpl(parallel_args, device, options); + } else if (worker_type == WorkerType::REC) { + // TODO. add following when next pr (use RecWorkerImpl). + // impl_ = new RecWorkerImpl(parallel_args, device, options); + // TODO. delete this when next pr. + impl_ = new LLMWorkerImpl(parallel_args, device, options); } else { LOG(ERROR) << "Unknown worker type, please check logic"; } diff --git a/xllm/core/scheduler/CMakeLists.txt b/xllm/core/scheduler/CMakeLists.txt index d694b3b1f..8acf272b4 100644 --- a/xllm/core/scheduler/CMakeLists.txt +++ b/xllm/core/scheduler/CMakeLists.txt @@ -20,6 +20,7 @@ cc_library( scheduler_factory.h decode_priority_queue.h perf_model.h + fixed_steps_scheduler.h SRCS chunked_prefill_scheduler.cpp zero_eviction_scheduler.cpp @@ -31,6 +32,7 @@ cc_library( prefill_only_scheduler.cpp scheduler_factory.cpp perf_model.cpp + fixed_steps_scheduler.cpp DEPS :batch :request @@ -54,4 +56,3 @@ cc_test( GTest::gtest_main $<$:nnopbase> ) - diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp index 59af7483b..8a37df3ac 100644 --- a/xllm/core/scheduler/continuous_scheduler.cpp +++ b/xllm/core/scheduler/continuous_scheduler.cpp @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "common/global_flags.h" #include "common/metrics.h" #include "distributed_runtime/engine.h" #include "framework/batch/batch_factory.h" @@ -37,14 +38,11 @@ limitations under the License. 
#include "util/utils.h" namespace xllm { -namespace { -constexpr size_t kRequestQueueSize = 100000; -} // namespace ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options) : options_(options), engine_(engine), - request_queue_(kRequestQueueSize), + request_queue_(FLAGS_request_queue_size), waiting_priority_queue_(create_comparator(options.priority_strategy())), waiting_priority_queue_offline_( create_comparator(options.priority_strategy())) { diff --git a/xllm/core/scheduler/continuous_scheduler.h b/xllm/core/scheduler/continuous_scheduler.h index f3b840156..a410e6e16 100644 --- a/xllm/core/scheduler/continuous_scheduler.h +++ b/xllm/core/scheduler/continuous_scheduler.h @@ -195,7 +195,7 @@ class ContinuousScheduler : public Scheduler { KVCacheManager* kv_cache_manager_; - // a thread safe queue of requests, bounded by kRequestQueueSize + // a thread safe queue of requests, bounded by FLAGS_request_queue_size // the schedule owns the requests and manages their lifetimes. folly::MPMCQueue> request_queue_; diff --git a/xllm/core/scheduler/fixed_steps_scheduler.cpp b/xllm/core/scheduler/fixed_steps_scheduler.cpp new file mode 100644 index 000000000..0d0d51411 --- /dev/null +++ b/xllm/core/scheduler/fixed_steps_scheduler.cpp @@ -0,0 +1,311 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "fixed_steps_scheduler.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common/metrics.h" +#include "distributed_runtime/engine.h" +#include "framework/batch/batch.h" +#include "framework/batch/batch_factory.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" + +namespace xllm { + +FixedStepsScheduler::FixedStepsScheduler(Engine* engine, const Options& options) + : ContinuousScheduler(engine, options) {} + +bool FixedStepsScheduler::add_request(std::shared_ptr& request) { + CHECK(request != nullptr); + CHECK(!request->sequences().empty()); + + if (request_queue_.write(request)) { //.get() + // take over the ownership of the request + // request.release(); + return true; + } + // queue is full + return false; +} + +void FixedStepsScheduler::handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + std::vector>& finished_requests) { + // Handle new request prompt first. + // Include those requests that are preempted by others. + // + // schedule the prefill requests in the waiting priority queue until budgets + // are exhausted. + // When the KV Cache usage reaches the threshold, prefill requests will no + // longer be scheduled to avoid frequent preemption. + // + // NOTE: preempted requests will be pushed in waiting_priority_queue, + // they may contian many sequences, so we should check here. 
+ bool budget_exhausted = false; + bool blocks_exhausted = false; + while (!waiting_priority_queue_.empty() && remaining_seq_budget > 0 && + remaining_token_budget > 0 && + kv_cache_manager_->kv_cache_utilization() < + FLAGS_prefill_scheduling_memory_usage_threshold) { + std::shared_ptr request(waiting_priority_queue_.top()); + if (request->finished() || request->cancelled()) { + // kv_cache_manager_->deallocate(request.get()); + // release the ownership of the request + finished_requests.emplace_back(request); + // remove the request from the priority queue + waiting_priority_queue_.pop(); + continue; + } + + const size_t num_sequences = request->sequences().size(); + if (!request->preempted()) { + CHECK(num_sequences == 1) + << "Waiting request should have only one sequence."; + } + + // TODO: FIXME later + // Optimization of the scheduling algorithm under multiple sequences + size_t allocated_tokens = 0; + size_t allocated_seqs = 0; + double allocated_estimate_latency = 0; + bool can_schedule = true; + std::vector prefill_sequences; + std::vector prefill_sequences_budget; + prefill_sequences.reserve(request->sequences().size()); + prefill_sequences_budget.reserve(request->sequences().size()); + for (auto& prefill_sequence : request->sequences()) { + if (prefill_sequence->finished()) { + continue; + } + + size_t num_tokens = prefill_sequence->num_need_compute_tokens(); + if (remaining_token_budget < allocated_tokens + num_tokens || + remaining_seq_budget < allocated_seqs + 1) { + can_schedule = false; + budget_exhausted = true; + break; + } + + prefill_sequences_budget.emplace_back(num_tokens); + prefill_sequences.emplace_back(prefill_sequence.get()); + allocated_tokens += num_tokens; + allocated_seqs += 1; + } + + if (!can_schedule) { + for (auto& seq : prefill_sequences) { + // release shared blocks + kv_cache_manager_->deallocate(seq); + } + break; + } + + if (prefill_sequences.empty()) { + continue; + } + + remaining_token_budget -= allocated_tokens; + remaining_seq_budget -= allocated_seqs; + waiting_priority_queue_.pop(); + running_requests_.emplace_back(request); + running_sequences_.insert(running_sequences_.end(), + prefill_sequences.begin(), + prefill_sequences.end()); + running_sequences_budgets_.insert(running_sequences_budgets_.end(), + prefill_sequences_budget.begin(), + prefill_sequences_budget.end()); + } + + if (running_sequences_.empty() && !waiting_priority_queue_.empty() && + running_queue_->empty()) { + LOG(ERROR) + << "Request prompt is too long, no enough budget/memory to schedule " + "a single sequence."; + // no enough memory to schedule single sequence, just finish the request + std::shared_ptr request(waiting_priority_queue_.top()); + waiting_priority_queue_.pop(); + // block_manager_->release_blocks_for(request.get()); + response_processor_->process_failed_request( + request, + {StatusCode::RESOURCE_EXHAUSTED, + "No enough budget to schedule single sequence."}); + } +} + +std::vector FixedStepsScheduler::prepare_batch() { + Timer timer; + // propogate new requests to waiting_priority_queue_ + // Include those requests that are preempted by others. + std::shared_ptr request; + // read from request queue then push to waiting priority queue + while (request_queue_.read(request)) { + CHECK(request); + + // expand sequences to the target number if prefix cache is disabled. 
+ if (!enable_prefix_cache_) { + // expand sequences to the target number + request->expand_sequences(false); + } + + if (request->sequences()[0]->kv_state().kv_cache_tokens_num() == 0) { + waiting_priority_queue_.push(request); + } else { + // request from prefill instance in disagge pd mode. + running_requests_.emplace_back(request); + } + } + + // handle finished/cancelled requests + std::vector> finished_requests; + for (auto it = running_requests_.rbegin(); it != running_requests_.rend(); + ++it) { + if (*it == nullptr) { + continue; + } + std::shared_ptr request = *it; + request->update_connection_status(); + if (request->finished() || request->cancelled()) { + // kv_cache_manager_->deallocate(request.get()); + // release the ownership of the request + finished_requests.emplace_back(request); + // finished request is set to nullptr + *it = nullptr; + } + } + + // clear previous batch + running_requests_.clear(); + running_sequences_.clear(); + running_sequences_budgets_.clear(); + + // remaining budget for the current batch + size_t remaining_token_budget = options_.max_tokens_per_batch(); + size_t remaining_seq_budget = std::max(options_.max_seqs_per_batch(), 1); + size_t num_preempted_requests = 0; + + handle_prefill_requests( + remaining_token_budget, remaining_seq_budget, finished_requests); + + // only forward once, no decode requests + // handle_decode_requests( + // remaining_token_budget, remaining_seq_budget, num_preempted_requests); + + if (!finished_requests.empty()) { + response_processor_->process_completed_requests(finished_requests); + } + + // update the batch + // TODO. add following when next pr (use create_rec_batches). + // auto batches = BatchFactory::get_instance(options_.dp_size()) + // ->create_rec_batches( + // running_requests_, + // running_sequences_, + // running_sequences_budgets_, + // kv_cache_manager_->get_swap_block_transfer_infos()); + // TODO. delete this when next pr. 
+ auto batches = + BatchFactory::get_instance(options_.dp_size()) + ->create_batches(running_requests_, + running_sequences_, + running_sequences_budgets_, + kv_cache_manager_->get_swap_block_transfer_infos()); + + // update metrics before returning + if (!batches[0].empty()) { + // only update the scheduling latency when there are requests to process + COUNTER_ADD(scheduling_latency_seconds, timer.elapsed_seconds()); + } + + GAUGE_SET(num_pending_requests, + pending_requests_.load(std::memory_order_relaxed)); + GAUGE_SET(num_running_requests, running_requests_.size()); + GAUGE_SET(num_waiting_requests, + waiting_priority_queue_.size() + running_queue_->size()); + + GAUGE_ADD(num_preempted_requests, num_preempted_requests); + + GAUGE_SET(num_running_sequences, running_sequences_.size()); + + GAUGE_SET(kv_cache_utilization_perc, + kv_cache_manager_->kv_cache_utilization()); + if (!FLAGS_enable_continuous_kvcache) { + GAUGE_SET(num_blocks_in_prefix_cache, + kv_cache_manager_->num_blocks_in_prefix_cache().size()); + GAUGE_SET(num_free_blocks, kv_cache_manager_->num_free_blocks().size()); + GAUGE_SET(num_used_blocks, kv_cache_manager_->num_used_blocks().size()); + } + return batches; +} + +std::vector FixedStepsScheduler::schedule_request( + const absl::Duration& timeout) { + const auto deadline = absl::Now() + timeout; + std::vector batch; + while (true) { + batch = prepare_batch(); + bool all_empty = + std::all_of(batch.begin(), batch.end(), [](const Batch& one_batch) { + return one_batch.empty(); + }); + if (!all_empty) { + return batch; + } + const auto now = absl::Now(); + if (now > deadline) { + break; + } + // wait for new requests to arrive + constexpr uint64_t kStepSleepTimeMs = 1; + const auto time_to_sleep = + std::min(absl::Milliseconds(kStepSleepTimeMs), deadline - now); + absl::SleepFor(time_to_sleep); + } + // return an empty batch + return batch; +} + +// step the scheduler forward by one step +// may get blocked if there are no requests to process +void FixedStepsScheduler::step(const absl::Duration& timeout) { + if (!options_.enable_schedule_overlap()) { + // get a new batch of requests + std::vector batch = schedule_request(timeout); + bool all_empty = + std::all_of(batch.begin(), batch.end(), [](const Batch& one_batch) { + return one_batch.empty(); + }); + if (all_empty) { + return; + } + engine_->step(batch); + kv_cache_manager_->reset_transfer_infos(); + } else { + LOG(ERROR) << "FixedStepsScheduler::step() not supported with " + "enable_schedule_overlap"; + } +} + +} // namespace xllm diff --git a/xllm/core/scheduler/fixed_steps_scheduler.h b/xllm/core/scheduler/fixed_steps_scheduler.h new file mode 100644 index 000000000..ab4d6c5de --- /dev/null +++ b/xllm/core/scheduler/fixed_steps_scheduler.h @@ -0,0 +1,62 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "async_response_processor.h" +#include "common/macros.h" +#include "common/types.h" +#include "framework/batch/batch.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" +#include "runtime/xservice_client.h" +#include "scheduler.h" +#include "scheduler/continuous_scheduler.h" + +namespace xllm { +class Engine; + +class FixedStepsScheduler final : public ContinuousScheduler { + public: + FixedStepsScheduler(Engine* engine, const Options& options); + virtual ~FixedStepsScheduler() = default; + + bool add_request(std::shared_ptr& request) override; + + // step the scheduler forward by one step + // may get blocked if there are no requests to process + void step(const absl::Duration& timeout) override; + + private: + std::vector schedule_request(const absl::Duration& timeout); + + // build a batch of requests from the priority queue + virtual std::vector prepare_batch(); + + void handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + std::vector>& finished_requests); +}; + +} // namespace xllm diff --git a/xllm/core/scheduler/scheduler_factory.cpp b/xllm/core/scheduler/scheduler_factory.cpp index 8be5a8b84..b257c61a8 100644 --- a/xllm/core/scheduler/scheduler_factory.cpp +++ b/xllm/core/scheduler/scheduler_factory.cpp @@ -19,6 +19,7 @@ limitations under the License. #include "scheduler/continuous_scheduler.h" #include "scheduler/disagg_pd_scheduler.h" #include "scheduler/dit_scheduler.h" +#include "scheduler/fixed_steps_scheduler.h" #include "scheduler/pd_ooc_scheduler.h" #include "scheduler/prefill_only_scheduler.h" #include "scheduler/zero_eviction_scheduler.h" @@ -57,4 +58,10 @@ std::unique_ptr create_dit_scheduler( return std::make_unique(engine, options); } +std::unique_ptr create_fixed_steps_scheduler( + Engine* engine, + ContinuousScheduler::Options options) { + return std::make_unique(engine, options); +} + } // namespace xllm diff --git a/xllm/core/scheduler/scheduler_factory.h b/xllm/core/scheduler/scheduler_factory.h index daf153bad..e28e74db3 100644 --- a/xllm/core/scheduler/scheduler_factory.h +++ b/xllm/core/scheduler/scheduler_factory.h @@ -18,6 +18,7 @@ limitations under the License. #include "runtime/xservice_client.h" #include "scheduler/continuous_scheduler.h" #include "scheduler/dit_scheduler.h" +#include "scheduler/fixed_steps_scheduler.h" namespace xllm { @@ -29,4 +30,8 @@ std::unique_ptr create_dit_scheduler( DiTEngine* engine, DiTScheduler::Options options); +std::unique_ptr create_fixed_steps_scheduler( + Engine* engine, + ContinuousScheduler::Options options); + } // namespace xllm
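
For reviewers, the short C++ sketch below (not part of the patch) illustrates how the new pieces compose for the "rec" backend: create_master("rec", ...) builds a RecMaster, its constructor wires up a RecEngine plus a FixedStepsScheduler, and each handle_request() call is queued by the scheduler and executed by RecEngine::step() as one prefill pass followed by two fixed decode passes before the batch is finished. Master::Options, the prompt_tokens element type, the callback shape, a default-constructed RequestParams, and the include paths are assumptions inferred from this diff, not verified against the full headers.

// Reviewer sketch (not part of this patch): wiring of the new "rec" backend.
// Signatures and include paths below are inferred from the diff only.
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "core/distributed_runtime/master.h"      // create_master() (assumed header)
#include "core/distributed_runtime/rec_master.h"  // RecMaster

void serve_rec_request(const xllm::Master::Options& options) {
  // master.cpp: backend == "rec" now builds a RecMaster, whose constructor
  // creates a RecEngine and a FixedStepsScheduler.
  std::unique_ptr<xllm::Master> master = xllm::create_master("rec", options);
  master->run();  // starts the loop thread that drives scheduler_->step()

  // Rec requests must carry pre-tokenized input (or mm_data);
  // RecMaster::generate_request() rejects requests that provide neither.
  auto* rec_master = dynamic_cast<xllm::RecMaster*>(master.get());
  std::vector<int> prompt_tokens = {101, 7, 42};  // element type assumed

  xllm::RequestParams params;  // assumed default-constructible
  rec_master->handle_request(
      /*prompt=*/"",
      prompt_tokens,
      /*mm_data=*/std::nullopt,
      std::move(params),
      [](const xllm::RequestOutput& output) {
        // RecEngine::step() runs one prefill pass and two decode passes,
        // then Batch::finish() marks every sequence finished, so this
        // callback receives the final output once.
        return true;
      });
}

Under this flow, the new FLAGS_request_queue_size bounds the scheduler's MPMC request queue; when FixedStepsScheduler::add_request() reports the queue as full, handle_request() responds with RESOURCE_EXHAUSTED.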