
Commit 5de436f

DragonFive authored and maxiaolong.maxwell committed
feat: add rec_type and onerec batch input builder.
1 parent 7f93e56 commit 5de436f

Showing 20 changed files with 1,652 additions and 42 deletions.

xllm/api_service/rec_completion_service_impl.h

Lines changed: 1 addition & 1 deletion
@@ -19,8 +19,8 @@ limitations under the License.
 
 #include "api_service_impl.h"
 #include "completion.pb.h"
-#include "rec.pb.h"
 #include "core/distributed_runtime/rec_master.h"
+#include "rec.pb.h"
 #include "stream_call.h"
 
 namespace xllm {

xllm/core/distributed_runtime/rec_master.cpp

Lines changed: 42 additions & 6 deletions
@@ -15,12 +15,11 @@ limitations under the License.
 
 #include "rec_master.h"
 
+#include <absl/time/time.h>
 #include <gflags/gflags.h>
 #include <glog/logging.h>
 #include <pybind11/pybind11.h>
 
-#include <absl/time/time.h>
-
 #include "common/macros.h"
 #include "common/metrics.h"
 #include "models/model_registry.h"
@@ -33,13 +32,32 @@ limitations under the License.
 
 namespace xllm {
 
+namespace {
+
+RecType InferRecTypeFromModelArgs(const ModelArgs& model_args) {
+  const auto& model_type = model_args.model_type();
+  if (model_type == "onerec") {
+    return RecType::kOneRec;
+  }
+  if (model_type == "qwen3rec") {
+    return RecType::kLlmRec;
+  }
+  return RecType::kNone;
+}
+
+} // namespace
+
 RecMaster::RecMaster(const Options& options)
     : Master(options, EngineType::REC) {
   // Initialize with Rec engine type
   // The rest of the initialization follows the same pattern as LLMMaster
   CHECK(engine_->init());
 
   model_args_ = engine_->model_args();
+  rec_type_ = InferRecTypeFromModelArgs(model_args_);
+  if (rec_type_ == RecType::kNone) {
+    LOG(ERROR) << "Unsupported rec model_type: " << model_args_.model_type();
+  }
 
   bool enable_decode_response_to_service = false;
   if (options_.enable_service_routing()) {
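The RecType enum used here lives in the new header framework/request/rec_type.h, which is among the 20 changed files but is not shown in this excerpt. A minimal sketch of what it plausibly contains, inferred from the three values referenced in this file; the exact layout and comments are assumptions:

// framework/request/rec_type.h -- sketch only, inferred from usage;
// not the committed file.
#pragma once

namespace xllm {

enum class RecType {
  kNone,    // model_type is not a recognized rec model
  kOneRec,  // model_type == "onerec"
  kLlmRec,  // model_type == "qwen3rec" (LLM-style rec model)
};

}  // namespace xllm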
@@ -72,7 +90,6 @@ RecMaster::RecMaster(const Options& options)
           .enable_decode_response_to_service(enable_decode_response_to_service);
   scheduler_ = create_fixed_steps_scheduler(engine_.get(), scheduler_options);
 
-  // OmniRec model does not have a tokenizer
   chat_template_ = nullptr;
   tokenizer_ = nullptr;
   threadpool_ =
@@ -163,6 +180,26 @@ std::shared_ptr<Request> RecMaster::generate_request(
   // contain the actual data Skip prompt empty check as mentioned in
   // requirements
 
+  if (rec_type_ == RecType::kNone) {
+    LOG(ERROR) << "Unsupported rec model_type: " << model_args_.model_type();
+    CALLBACK_WITH_ERROR(
+        StatusCode::INVALID_ARGUMENT,
+        std::string("Unsupported rec model_type: ") + model_args_.model_type());
+    return nullptr;
+  }
+
+  // qwen3rec requires tokenizer and RecBatchInputBuilder support. This PR keeps
+  // the extension point but rejects requests early to avoid LOG(FATAL) in the
+  // batch builder path.
+  if (rec_type_ == RecType::kLlmRec) {
+    LOG(ERROR) << "Rec model_type is not supported yet: "
+               << model_args_.model_type();
+    CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
+                        std::string("Rec model_type is not supported yet: ") +
+                            model_args_.model_type());
+    return nullptr;
+  }
+
   Timer timer;
   std::vector<int> local_prompt_tokens;
 
@@ -258,9 +295,8 @@ std::shared_ptr<Request> RecMaster::generate_request(
                            callback,
                            nullptr,
                            sp.decode_address);
-  // TODO. add following when next pr (add is_rec_model and bos_token_id to
-  // RequestState). req_state.is_rec_model = true; req_state.bos_token_id =
-  // model_args_.bos_token_id();
+  req_state.rec_type = rec_type_;
+  req_state.bos_token_id = model_args_.bos_token_id();
   auto request = std::make_shared<Request>(sp.request_id,
                                            sp.x_request_id,
                                            sp.x_request_time,
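The two req_state assignments above imply that RequestState gains rec_type and bos_token_id members in this commit (the removed TODO had planned an is_rec_model flag; the commit evidently settled on the richer RecType instead). A sketch of the assumed additions, with defaults guessed:

// framework/request/request_state.h -- assumed additions only (sketch).
#include "framework/request/rec_type.h"

struct RequestState {
  // ... existing members ...
  RecType rec_type = RecType::kNone;  // set by RecMaster::generate_request
  int32_t bos_token_id = 0;           // copied from model_args_.bos_token_id()
};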

xllm/core/distributed_runtime/rec_master.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "framework/chat_template/jinja_chat_template.h"
 #include "framework/model/model_args.h"
+#include "framework/request/rec_type.h"
 #include "master.h"
 #include "rec_engine.h"
 #include "scheduler/continuous_scheduler.h"

@@ -55,6 +56,7 @@ class RecMaster : public Master {
   std::unique_ptr<FixedStepsScheduler> scheduler_;
   // model args
   ModelArgs model_args_;
+  RecType rec_type_ = RecType::kNone;
   std::unique_ptr<ThreadPool> threadpool_;
   std::unique_ptr<Tokenizer> tokenizer_;
   // chat template instance

xllm/core/framework/batch/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -10,12 +10,16 @@ cc_library(
     batch.h
     batch_factory.h
     batch_input_builder.h
+    rec_batch_input_builder.h
+    onerec_batch_input_builder.h
     mposition.h
   SRCS
     dit_batch.cpp
     batch.cpp
     batch_factory.cpp
     batch_input_builder.cpp
+    rec_batch_input_builder.cpp
+    onerec_batch_input_builder.cpp
     mposition.cpp
     beam_search.h
   DEPS
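rec_batch_input_builder.h/.cpp and onerec_batch_input_builder.h/.cpp are new in this commit, but their contents are not part of this excerpt. From the call site in Batch::prepare_rec_forward_input below, the base interface is plausibly shaped as follows; the parameter types are inferred from the Batch members passed in, and every name other than Create and build_rec_forward_input is an assumption:

// rec_batch_input_builder.h -- interface sketch inferred from the call site.
class RecBatchInputBuilder {
 public:
  // Factory: picks the concrete builder (e.g. a OneRecBatchInputBuilder for
  // RecType::kOneRec). Trailing parameters elided; types are assumptions.
  static std::unique_ptr<RecBatchInputBuilder> Create(
      RecType rec_type,
      const std::vector<SequencesGroup*>& sequence_groups,
      const std::vector<uint32_t>& allowed_max_tokens,
      /* input embeddings, mm data, swap infos, batch id, args, thread pool */);

  virtual ~RecBatchInputBuilder() = default;

  virtual ForwardInput build_rec_forward_input(
      uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size) = 0;
};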

xllm/core/framework/batch/batch.cpp

Lines changed: 50 additions & 4 deletions
@@ -29,6 +29,7 @@ limitations under the License.
 #include "framework/model/model_input_params.h"
 #include "framework/request/sequence.h"
 #include "framework/sampling/sampling_params.h"
+#include "rec_batch_input_builder.h"
 #include "runtime/params_utils.h"
 #include "util/slice.h"
 #include "util/tensor_helper.h"
@@ -96,6 +97,10 @@ void Batch::add(const std::vector<Sequence*>& sequences) {
 ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
                                           uint32_t min_decoding_batch_size,
                                           const ModelArgs& args) {
+  if (sequences_.empty() && !sequence_groups_.empty()) {
+    return prepare_rec_forward_input(
+        num_decoding_tokens, min_decoding_batch_size, args);
+  }
   BatchInputBuilder builder(sequences_,
                             allowed_max_tokens_,
                             input_embeddings_vec_,
@@ -108,6 +113,43 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
                                      min_decoding_batch_size);
 }
 
+ForwardInput Batch::prepare_rec_forward_input(uint32_t num_decoding_tokens,
+                                              uint32_t min_decoding_batch_size,
+                                              const ModelArgs& args,
+                                              ThreadPool* thread_pool) {
+  RecType rec_type = RecType::kNone;
+  if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) {
+    rec_type = sequence_groups_[0]->sequences()[0]->rec_type();
+  }
+
+  auto builder = RecBatchInputBuilder::Create(rec_type,
+                                              sequence_groups_,
+                                              allowed_max_tokens_,
+                                              input_embeddings_vec_,
+                                              mm_data_vec_,
+                                              swap_block_transfer_infos_,
+                                              batch_id_,
+                                              &args,
+                                              thread_pool);
+  return builder->build_rec_forward_input(num_decoding_tokens,
+                                          min_decoding_batch_size);
+}
+
+std::vector<Sequence*> Batch::get_sequences() const {
+  if (!sequences_.empty()) {
+    return sequences_;
+  }
+
+  std::vector<Sequence*> result;
+  for (const auto* seq_group : sequence_groups_) {
+    const auto& sequences = seq_group->sequences();
+    for (const auto& seq_ptr : sequences) {
+      result.push_back(seq_ptr.get());
+    }
+  }
+  return result;
+}
+
 void Batch::dp_balance_shuffle_seqs() {
   // this shuffle operation is mainly used for npu with 24 cores
   // and specific mla op implementation
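A plausible shape for the Create dispatch in rec_batch_input_builder.cpp, consistent with the early-rejection comment in rec_master.cpp ("avoid LOG(FATAL) in the batch builder path"); this is a sketch, and only the kOneRec branch has a concrete builder in this commit:

// rec_batch_input_builder.cpp -- dispatch sketch, not the committed code.
std::unique_ptr<RecBatchInputBuilder> RecBatchInputBuilder::Create(
    RecType rec_type /* , forwarded builder arguments ... */) {
  switch (rec_type) {
    case RecType::kOneRec:
      // The only builder implemented by this commit.
      return std::make_unique<OneRecBatchInputBuilder>(/* ... */);
    case RecType::kLlmRec:  // rejected earlier in RecMaster::generate_request
    case RecType::kNone:
    default:
      LOG(FATAL) << "Unsupported rec type for batch input building";
      return nullptr;
  }
}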
@@ -217,7 +259,8 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output,
   // this means all sequences are in prefill stage status.
   const int64_t num_seqs = raw_output.outputs.size();
   int64_t output_idx = 0;
-  for (auto* seq : sequences_) {
+  const auto sequences = get_sequences();
+  for (auto* seq : sequences) {
     if (seq->finished()) {
       output_idx++;
       continue;
@@ -264,7 +307,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output,
   if (sample_output.embeddings.defined()) {
     const int64_t num_seqs = sample_output.embeddings.size(0);
     int64_t output_idx = 0;
-    for (auto* seq : sequences_) {
+    const auto sequences = get_sequences();
+    for (auto* seq : sequences) {
       CHECK_LT(output_idx, num_seqs);
       auto cur_seq_embed =
           safe_to(sample_output.embeddings[output_idx++], torch::kFloat32);
@@ -277,7 +321,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output,
     // this means all sequences are in prefill stage status.
     const int64_t num_seqs = sample_output.next_tokens.size(0);
     int64_t output_idx = 0;
-    for (auto* seq : sequences_) {
+    const auto sequences = get_sequences();
+    for (auto* seq : sequences) {
       if (seq->finished()) {
         output_idx++;
         continue;
@@ -352,7 +397,8 @@ void Batch::process_embedding_output(const torch::Tensor& output_embedding) {
   Token token(0);
   if (output_embedding.defined()) {
     int32_t slice_img_index = 0;
-    for (auto* seq : sequences_) { // TODO
+    const auto sequences = get_sequences();
+    for (auto* seq : sequences) {
       const auto& mm_data = seq->get_mm_data();
 
       auto pixel_values = mm_data.get_tensor_vec("pixel_values");

xllm/core/framework/batch/batch.h

Lines changed: 10 additions & 2 deletions
@@ -75,7 +75,7 @@ class Batch {
 
   // get the number of sequences in the batch
   size_t size() const { return sequences_.size(); }
-  bool empty() const { return sequences_.empty(); }
+  bool empty() const { return sequences_.empty() && sequence_groups_.empty(); }
 
   Sequence* operator[](size_t i) { return sequences_[i]; }
 
@@ -84,6 +84,11 @@
                                     uint32_t min_decoding_bach_size,
                                     const ModelArgs& args);
 
+  ForwardInput prepare_rec_forward_input(uint32_t num_decoding_tokens,
+                                         uint32_t min_decoding_batch_size,
+                                         const ModelArgs& args,
+                                         ThreadPool* thread_pool = nullptr);
+
   // Convert Batch to pb type, which will be pass to remote worker.
   RawForwardInput prepare_forward_input(const ModelArgs& args,
                                         ThreadPool* thread_pool);
@@ -110,7 +115,8 @@
   // process the accepted output embedding
   void process_embedding_output(const torch::Tensor& embedding);
 
-  // mark all sequence groups as finished (used by rec model multi-round decoding)
+  // mark all sequence groups as finished (used by rec model multi-round
+  // decoding)
   void finish();
 
   const std::vector<uint32_t>& get_allowed_max_tokens() const {
@@ -137,6 +143,8 @@
 
   void dp_balance_shuffle_seqs();
 
+  std::vector<Sequence*> get_sequences() const;
+
   std::vector<Sequence*> sequences_;
   std::vector<SequencesGroup*> sequence_groups_;
   std::vector<BlockTransferInfo>* swap_block_transfer_infos_ = nullptr;
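Taken together, the Batch changes let a group-only rec batch flow through the existing entry point. A minimal driver sketch, assuming seq_group holds onerec sequences and using illustrative argument values:

// Sketch: a rec batch is built from a SequencesGroup, not flat sequences.
Batch batch;
batch.add(seq_group);   // fills sequence_groups_; sequences_ stays empty
CHECK(!batch.empty());  // empty() now also checks sequence_groups_

// prepare_forward_input() sees sequences_ empty but sequence_groups_ set,
// and forwards to prepare_rec_forward_input(), which resolves the RecType
// from the first sequence and builds via RecBatchInputBuilder::Create().
ForwardInput input = batch.prepare_forward_input(
    /*num_decoding_tokens=*/1, /*min_decoding_batch_size=*/1, args);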

xllm/core/framework/batch/batch_factory.cpp

Lines changed: 58 additions & 0 deletions
@@ -92,4 +92,62 @@ std::vector<Batch> BatchFactory::create_batches(
   return batches;
 }
 
+std::vector<Batch> BatchFactory::create_rec_batches(
+    const std::vector<std::shared_ptr<Request>>& running_requests,
+    const std::vector<Sequence*>& running_sequences,
+    const std::vector<size_t>& running_sequences_budgets,
+    std::vector<std::vector<BlockTransferInfo>>* swap_block_transfer_infos) {
+  size_t num_prompt_tokens = 0;
+  size_t num_generated_tokens = 0;
+  std::vector<Batch> batches(dp_size_);
+  for (size_t i = 0; i < running_sequences.size(); ++i) {
+    auto* sequence = running_sequences[i];
+    const size_t token_budget = running_sequences_budgets[i];
+
+    const size_t remaining_prompt_tokens =
+        sequence->num_prompt_tokens() >
+                sequence->kv_state().kv_cache_tokens_num()
+            ? sequence->num_prompt_tokens() -
+                  sequence->kv_state().kv_cache_tokens_num()
+            : 0;
+    const size_t prompt_tokens =
+        std::min(remaining_prompt_tokens, token_budget);
+    const size_t generated_tokens = token_budget - prompt_tokens;
+    num_prompt_tokens += prompt_tokens;
+    num_generated_tokens += generated_tokens;
+
+    batches[sequence->dp_rank()].set_batch_id();
+  }
+
+  for (const auto& request : running_requests) {
+    auto seq_group = request->sequence_group();
+    int32_t dp_rank = seq_group->dp_rank();
+    batches[dp_rank].add(seq_group);
+  }
+
+  for (int i = 0; i < dp_size_; i++) {
+    if (!batches[i].empty()) {
+      if (swap_block_transfer_infos != nullptr &&
+          swap_block_transfer_infos->size() == dp_size_) {
+        batches[i].set_swap_block_transfer_infos(
+            &(swap_block_transfer_infos->at(i)));
+      }
+    }
+  }
+
+  COUNTER_ADD(num_processing_tokens_total_prompt, num_prompt_tokens);
+  COUNTER_ADD(num_processing_tokens_total_generated, num_generated_tokens);
+
+  if (running_sequences.size() > 0) {
+    HISTOGRAM_OBSERVE(
+        num_prompt_tokens_per_request,
+        static_cast<int64_t>(num_prompt_tokens / running_sequences.size()));
+    HISTOGRAM_OBSERVE(
+        num_generated_tokens_per_request,
+        static_cast<int64_t>(num_generated_tokens / running_sequences.size()));
+  }
+
+  return batches;
+}
+
 } // namespace xllm
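The prompt/generated accounting in create_rec_batches is saturating arithmetic over each sequence's token budget; a worked example with illustrative numbers:

// Suppose num_prompt_tokens() = 100, kv_cache_tokens_num() = 60, and
// token_budget = 50 for one sequence:
//   remaining_prompt_tokens = 100 - 60 = 40  // prompt not fully prefilled
//   prompt_tokens           = min(40, 50) = 40
//   generated_tokens        = 50 - 40 = 10   // leftover budget decodes
// A fully prefilled sequence (kv cache >= prompt length) contributes its
// whole budget to generated_tokens instead.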

xllm/core/framework/batch/batch_factory.h

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,13 @@ class BatchFactory {
       std::vector<std::vector<BlockTransferInfo>>* swap_block_transfer_infos =
           nullptr);
 
+  std::vector<Batch> create_rec_batches(
+      const std::vector<std::shared_ptr<Request>>& running_requests,
+      const std::vector<Sequence*>& running_sequences,
+      const std::vector<size_t>& running_sequences_budgets,
+      std::vector<std::vector<BlockTransferInfo>>* swap_block_transfer_infos =
+          nullptr);
+
  private:
   BatchFactory(int32_t dp_size) : dp_size_(dp_size) {}
   ~BatchFactory() = default;
