
Commit 2dd568e

feat: support deepseek mtp on mlu device. (#454)
Co-authored-by: phantomlei <phantomlei3@gmail.com>
1 parent 6b246e2 commit 2dd568e

File tree: 10 files changed, +488 −40 lines

xllm/core/framework/model/model_args.h

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ struct ModelArgs {
   PROPERTY(int32_t, v_head_dim) = 0;
   PROPERTY(int32_t, q_lora_rank) = 0;
   PROPERTY(int32_t, kv_lora_rank) = 0;
+  // deepseek v3/v3.2 MTP
+  PROPERTY(int32_t, num_nextn_predict_layers) = 0;
 
   // deepseek v3.2 indexer
   PROPERTY(int32_t, index_head_dim) = 0;
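
Note: num_nextn_predict_layers mirrors the field of the same name in the DeepSeek V3 checkpoint's config.json and gives the number of extra next-token-prediction (MTP) layers appended after the main decoder stack. A minimal sketch of the indexing this implies; the ParsedConfig struct and the layer-index convention shown are illustrative assumptions, not the repo's actual loader:

    #include <cstdint>
    #include <iostream>

    // Hypothetical stand-in for the parsed checkpoint config; the real code
    // reads these values from the model's config.json.
    struct ParsedConfig {
      int32_t num_hidden_layers = 61;
      int32_t num_nextn_predict_layers = 1;  // 0 means no MTP layers
    };

    int main() {
      ParsedConfig cfg;
      // The target model uses layer indices [0, num_hidden_layers); the MTP
      // layers occupy the next num_nextn_predict_layers slots (assumption).
      for (int32_t i = 0; i < cfg.num_nextn_predict_layers; ++i) {
        std::cout << "MTP layer index: " << cfg.num_hidden_layers + i << "\n";
      }
    }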

xllm/core/framework/model/model_input_params.h

Lines changed: 11 additions & 0 deletions
@@ -161,6 +161,17 @@ struct ModelInputParams {
     LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
               << dp_global_token_nums;
   }
+
+  int32_t get_q_seq_len(int32_t seq_idx) const {
+#if defined(USE_NPU)
+    CHECK(seq_idx < q_seq_lens_vec.size()) << "seq_idx out of range";
+    return q_seq_lens_vec[seq_idx];
+#else
+    CHECK(seq_idx < q_seq_lens_vec.size() - 1) << "seq_idx out of range";
+    return q_seq_lens_vec[seq_idx + 1] - q_seq_lens_vec[seq_idx];
+#endif
+  }
+
   // whether the kv-cache is empty for all sequences.
   bool empty_kv_cache = true;
   BatchForwardType batch_forward_type;
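
Note: get_q_seq_len papers over two storage conventions. On NPU, q_seq_lens_vec holds per-sequence lengths directly; on the other backends it holds a prefix sum with a leading zero (the cu_seqlens-style layout), so adjacent differences recover each length. A minimal standalone sketch of the two layouts, assuming a batch with query lengths {3, 1, 4}:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // NPU convention: lengths stored directly, one entry per sequence.
      std::vector<int32_t> direct = {3, 1, 4};
      // MLU/CUDA convention: prefix sum with a leading 0, so the vector has
      // num_sequences + 1 entries.
      std::vector<int32_t> cumulative = {0, 3, 4, 8};
      for (size_t i = 0; i < direct.size(); ++i) {
        // Adjacent differences in the cumulative layout recover the lengths.
        assert(cumulative[i + 1] - cumulative[i] == direct[i]);
      }
      return 0;
    }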

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 154 additions & 20 deletions
@@ -58,6 +58,98 @@ int32_t kv_cache_slot_id(int32_t position,
   return block_id * block_size + block_offset;
 }
 
+// Convert tensor to int64 for MLU platform (temp workaround)
+// MLU will support int32 for masked_scatter in the future
+torch::Tensor ensure_int64_for_certain_platform(torch::Tensor tensor) {
+#if defined(USE_MLU)
+  return tensor.to(torch::kInt64);
+#else
+  return tensor;
+#endif
+}
+
+// Push cumulative sum to vector (used for cumulative format)
+void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
+  if (vec.empty()) {
+    vec.emplace_back(0);
+  }
+  vec.emplace_back(vec.back() + len);
+}
+
+// Calculate actual kv_len based on platform type
+// For NPU: direct format - returns kv_seq_lens_slice[seq_id] + offset
+// For MLU/CUDA: cumulative format - returns the actual length increment
+int32_t calculate_kv_len(const Slice<int32_t>& kv_seq_lens_slice,
+                         int32_t seq_id,
+                         int32_t offset) {
+#if defined(USE_NPU)
+  return kv_seq_lens_slice[seq_id] + offset;
+#elif defined(USE_MLU) || defined(USE_CUDA)
+  return kv_seq_lens_slice[seq_id + 1] - kv_seq_lens_slice[seq_id] + offset;
+#endif
+}
+
+// Append sequence length to vector based on platform type
+// For NPU: directly add the len value
+// For MLU/CUDA: add using cumulative format
+void append_seq_len(std::vector<int32_t>& vec, int32_t len) {
+#if defined(USE_NPU)
+  vec.emplace_back(len);
+#elif defined(USE_MLU) || defined(USE_CUDA)
+  push_cumsum(vec, len);
+#endif
+}
+
+// Update kv_seq_lens_vec and kv_max_seq_len
+void update_kv_seq_lens_and_max(std::vector<int32_t>& kv_seq_lens_vec,
+                                int32_t kv_len,
+                                int32_t& kv_max_seq_len) {
+  // Update max (same logic for both platforms)
+  if (kv_len > kv_max_seq_len) {
+    kv_max_seq_len = kv_len;
+  }
+  // Update kv_seq_lens_vec
+  append_seq_len(kv_seq_lens_vec, kv_len);
+}
+
+// Batch expansion strategy for validation
+void batch_expansion_process_seq_lens(
+    std::vector<int32_t>& kv_seq_lens_vec,
+    std::vector<int32_t>& q_seq_lens_vec,
+    std::vector<std::vector<int32_t>>& block_tables_vec,
+    int32_t& kv_max_seq_len,
+    const Slice<int32_t>& kv_seq_lens_slice,
+    const Slice<int32_t>& block_table_slice,
+    int32_t seq_id,
+    int32_t position_offset,
+    int32_t num_val_tokens) {
+  for (int32_t offset = position_offset;
+       offset < num_val_tokens + position_offset;
+       ++offset) {
+    // Calculate kv length and update kv_seq_lens_vec and kv_max_seq_len
+    int32_t kv_len = calculate_kv_len(kv_seq_lens_slice, seq_id, offset);
+    update_kv_seq_lens_and_max(kv_seq_lens_vec, kv_len, kv_max_seq_len);
+    // Append sequence length of 1 to q_seq_lens_vec
+    // for batch expansion strategy for validation
+    append_seq_len(q_seq_lens_vec, 1);
+    // Append block table to block_tables_vec
+    block_tables_vec.emplace_back(block_table_slice);
+  }
+}
+
+// Update kv_seq_lens_vec based on platform type
+// For NPU: directly add kv_seq_lens_slice[seq_id] + offset
+// For others: build cumulative format
+// Also updates kv_max_seq_len to track the maximum sequence length
+void update_kv_seq_lens_vec(std::vector<int32_t>& kv_seq_lens_vec,
+                            const Slice<int32_t>& kv_seq_lens_slice,
+                            int32_t seq_id,
+                            int32_t offset,
+                            int32_t& kv_max_seq_len) {
+  int32_t kv_len = calculate_kv_len(kv_seq_lens_slice, seq_id, offset);
+  update_kv_seq_lens_and_max(kv_seq_lens_vec, kv_len, kv_max_seq_len);
+}
+
 }  // namespace
 
 SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
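
Note: the helpers above keep kv_seq_lens_vec in whichever layout the backend expects. A standalone sketch (plain STL, no repo types) showing that push_cumsum plus adjacent differences round-trips the same lengths the NPU path stores directly:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Same logic as push_cumsum above, reproduced so the sketch compiles alone.
    void push_cumsum(std::vector<int32_t>& vec, int32_t len) {
      if (vec.empty()) vec.emplace_back(0);
      vec.emplace_back(vec.back() + len);
    }

    int main() {
      std::vector<int32_t> lens = {17, 5, 9};  // per-sequence kv lengths
      std::vector<int32_t> cum;
      for (int32_t len : lens) push_cumsum(cum, len);
      // cum == {0, 17, 22, 31}; differences recover the direct-format lengths,
      // which is what calculate_kv_len computes on MLU/CUDA.
      for (size_t i = 0; i < lens.size(); ++i) {
        assert(cum[i + 1] - cum[i] == lens[i]);
      }
      return 0;
    }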
@@ -68,6 +160,11 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
   runtime_options.enable_schedule_overlap(false);
   impl_ =
       std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
+  // Set num_speculative_tokens to 0 so the worker can tell this is the draft
+  // model when enable_speculative_decode is on.
+  // NOTE: If you want to modify this part, make sure you also check the usage
+  // of num_speculative_tokens in the draft model.
   runtime_options.num_decoding_tokens(1).num_speculative_tokens(0);
   draft_impl_ =
       std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
@@ -196,13 +293,15 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
 
   // prepare input for draft model
   auto& embeddings = output.sample_output.embeddings;
-  auto next_tokens = safe_to(output.sample_output.next_tokens, torch::kInt);
+  auto next_tokens = ensure_int64_for_certain_platform(
+      safe_to(output.sample_output.next_tokens, torch::kInt));
 
   if (embeddings.defined()) {
     prefill_input.input_params.input_embedding = embeddings.clone();
   }
   if (next_tokens.defined()) {
     auto& token_ids = prefill_input.token_ids;
+    token_ids = ensure_int64_for_certain_platform(token_ids);
     auto mask = (token_ids == -1);
     token_ids.masked_scatter_(mask, next_tokens);
   }
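
Note: the -1 entries in token_ids are placeholders for tokens the target model has not produced yet; masked_scatter_ fills them in order from next_tokens (on MLU both tensors must be int64, hence the workaround above). A minimal LibTorch sketch of the same mechanism, with illustrative values:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Token ids with -1 marking slots to be filled by sampled tokens.
      torch::Tensor token_ids =
          torch::tensor({11, 12, -1, 21, -1}, torch::kInt64);
      torch::Tensor next_tokens = torch::tensor({100, 200}, torch::kInt64);
      torch::Tensor mask = (token_ids == -1);
      // Fills masked positions left-to-right with consecutive source elements.
      token_ids.masked_scatter_(mask, next_tokens);
      std::cout << token_ids << "\n";  // 11, 12, 100, 21, 200
    }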
@@ -259,7 +358,7 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
   new_token_ids.reserve(input.token_ids.numel());
   for (size_t i = 0; i < input_params.num_sequences; ++i) {
     int32_t q_len = 0;
-    q_len = input_params.q_seq_lens_vec[i];
+    q_len = input_params.get_q_seq_len(i);
     Slice<int32_t> tokens_ids_slice_i =
         tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
     start_idx += q_len;
@@ -316,9 +415,10 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
 
   for (int i = 0; i < options_.num_speculative_tokens(); ++i) {
     ForwardOutput draft_output = draft_outputs[i];
-    auto next_tokens =
-        safe_to(draft_output.sample_output.next_tokens, torch::kInt);
+    auto next_tokens = ensure_int64_for_certain_platform(
+        safe_to(draft_output.sample_output.next_tokens, torch::kInt));
     auto& token_ids = validate_input.token_ids;
+    token_ids = ensure_int64_for_certain_platform(token_ids);
     auto mask = (token_ids == -1 * (i + 1));
     token_ids.masked_scatter_(mask, next_tokens);
   }
@@ -381,9 +481,13 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input,
   Slice<int32_t> kv_seq_lens_slice = input_params.kv_seq_lens_vec;
   torch::Tensor block_tables = safe_to(input_params.block_tables, torch::kCPU);
 
+  // Initialize kv_max_seq_len to 0
+  int32_t kv_max_seq_len = 0;
+
   for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
     new_positions.emplace_back(positions_slice[seq_id] + offset);
-    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
+    update_kv_seq_lens_vec(
+        kv_seq_lens_vec, kv_seq_lens_slice, seq_id, offset, kv_max_seq_len);
     torch::Tensor block_table = block_tables[seq_id];
     Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
                                         block_table.numel()};
@@ -394,7 +498,7 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(
 
   draft_input.positions = torch::tensor(new_positions, int_options);
   // update the input_params
-  input_params.kv_max_seq_len = input_params.kv_max_seq_len + offset;
+  input_params.kv_max_seq_len = kv_max_seq_len;
   input_params.kv_seq_lens_vec = kv_seq_lens_vec;
   input_params.kv_seq_lens = torch::tensor(kv_seq_lens_vec, int_options);
   input_params.new_cache_slots = torch::tensor(new_token_slot_ids, int_options);
@@ -438,6 +542,7 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
   std::vector<int32_t> new_token_slot_ids;
   std::vector<std::vector<int32_t>> block_tables_vec;
 
+  int32_t kv_max_seq_len = 0;
   for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
     new_token_ids.emplace_back(tokens_ids_slice[seq_id]);
     new_positions.emplace_back(positions_slice[seq_id] + position_offset);
@@ -453,17 +558,27 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
 
     // process kv length and q length
     if (FLAGS_enable_atb_spec_kernel) {
+      // expand the num of decode tokens for each batch in the batch for
+      // validation
      kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] +
                                    num_speculative_tokens + position_offset);
       q_seq_lens_vec.emplace_back(num_val_tokens);
-    } else {
-      for (int32_t offset = position_offset;
-           offset < num_val_tokens + position_offset;
-           ++offset) {
-        q_seq_lens_vec.emplace_back(1);
-        kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + offset);
-        block_tables_vec.emplace_back(block_table_slice);
+      // update max for NPU: direct format, compare with new value
+      if (kv_seq_lens_vec.back() > kv_max_seq_len) {
+        kv_max_seq_len = kv_seq_lens_vec.back();
       }
+    } else {
+      // expand the batch sizes for validation
+      // and update max for MLU/CUDA: cumulative format, compare with new value
+      batch_expansion_process_seq_lens(kv_seq_lens_vec,
+                                       q_seq_lens_vec,
+                                       block_tables_vec,
+                                       kv_max_seq_len,
+                                       kv_seq_lens_slice,
+                                       block_table_slice,
+                                       seq_id,
+                                       position_offset,
+                                       num_val_tokens);
     }
 
     // process slot id
@@ -490,8 +605,7 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
   input_params.q_seq_lens_vec = std::move(q_seq_lens_vec);
   input_params.q_seq_lens =
       torch::tensor(input_params.q_seq_lens_vec, int_options);
-  input_params.kv_max_seq_len =
-      *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
+  input_params.kv_max_seq_len = kv_max_seq_len;
   input_params.kv_seq_lens_vec = std::move(kv_seq_lens_vec);
   input_params.kv_seq_lens =
       torch::tensor(input_params.kv_seq_lens_vec, int_options);
@@ -573,6 +687,7 @@ SampleOutput SpeculativeWorkerImpl::validate(
   size_t num_draft_tokens = num_target_tokens - batch_size;
   COUNTER_ADD(speculative_num_draft_tokens_total, num_draft_tokens);
   COUNTER_ADD(speculative_num_accepted_tokens_total, num_draft_tokens - count);
+
   return sample_output;
 }
 
@@ -591,11 +706,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
   torch::Tensor positions = safe_to(inputs.positions, torch::kCPU);
   Slice<int32_t> positions_slice = {positions.data_ptr<int32_t>(),
                                     positions.numel()};
+  // Get the tokens generated in the last step (flattened for easier indexing)
   torch::Tensor last_token_ids = safe_to(
       last_step_output_.sample_output.next_tokens.flatten(), torch::kCPU);
   Slice<int64_t> last_tokens_ids_slice = {last_token_ids.data_ptr<int64_t>(),
                                           last_token_ids.numel()};
 
+  // Determine how many tokens were decoded in the last step
+  // If the output is 2D, it means multiple tokens were generated per sequence
   int32_t last_step_decode_num = 1;
   if (last_step_output_.sample_output.next_tokens.dim() == 2) {
     last_step_decode_num = last_step_output_.sample_output.next_tokens.size(1);
@@ -613,13 +731,23 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
   kv_seq_lens_vec.reserve(num_sequences);
   new_token_slot_ids.reserve(num_sequences);
 
-  // get right token id and position
+  // Initialize kv_max_seq_len to 0
+  int32_t kv_max_seq_len = 0;
+
+  // Process each sequence to get the correct token ID and position for the
+  // next step
   for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) {
     int32_t postion_offset = 0;
     int32_t last_step_token_id = 0;
+
+    // If the token ID is non-negative, it's a direct token ID (not a
+    // placeholder)
     if (tokens_ids_slice[seq_id] >= 0) {
       last_step_token_id = tokens_ids_slice[seq_id];
     } else {
+      // Negative token IDs are placeholders that need to be resolved from
+      // last_step_output_. The absolute value minus 1 gives the index into
+      // the last step's output.
       int32_t last_step_index = -1 * tokens_ids_slice[seq_id] - 1;
       last_step_index = last_step_index * last_step_decode_num;
       postion_offset = -1;
@@ -634,8 +762,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
 
     new_token_ids.emplace_back(last_step_token_id);
     new_positions.emplace_back(positions_slice[seq_id] + postion_offset);
-    kv_seq_lens_vec.emplace_back(kv_seq_lens_slice[seq_id] + postion_offset);
-
+    update_kv_seq_lens_vec(kv_seq_lens_vec,
+                           kv_seq_lens_slice,
+                           seq_id,
+                           postion_offset,
+                           kv_max_seq_len);
+
+    // Calculate the new cache slot ID based on the position offset
+    // This handles cases where we need to move to a different block
     torch::Tensor block_table = block_tables[seq_id];
     Slice<int32_t> block_table_slice = {block_table.data_ptr<int32_t>(),
                                         block_table.numel()};
@@ -644,12 +778,12 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
     new_token_slot_ids.emplace_back(slot_id);
   }
 
+  // Create new tensors with updated values
   torch::TensorOptions int_options = inputs.token_ids.options();
   new_inputs.token_ids = torch::tensor(new_token_ids, int_options);
   new_inputs.positions = torch::tensor(new_positions, int_options);
   // update the input_params
-  input_params.kv_max_seq_len =
-      *std::max_element(kv_seq_lens_vec.begin(), kv_seq_lens_vec.end());
+  input_params.kv_max_seq_len = kv_max_seq_len;
   input_params.kv_seq_lens_vec = std::move(kv_seq_lens_vec);
   input_params.kv_seq_lens =
       torch::tensor(input_params.kv_seq_lens_vec, int_options);
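
Note: without the fused ATB kernel, validation uses batch expansion: a sequence validating num_val_tokens tokens is unrolled into num_val_tokens single-token rows, each with q_len 1 and a kv length one longer than the previous row's. A self-contained sketch of that bookkeeping for one sequence with kv length 10 validating 3 tokens, using the cumulative layout as on MLU/CUDA (values are illustrative):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int32_t kv_len = 10;         // kv length before validation
      const int32_t num_val_tokens = 3;  // tokens to validate this step
      std::vector<int32_t> q_cum = {0}, kv_cum = {0};
      for (int32_t offset = 0; offset < num_val_tokens; ++offset) {
        // Each expanded row scores exactly one query token...
        q_cum.emplace_back(q_cum.back() + 1);
        // ...against a kv prefix that grows by one token per row.
        kv_cum.emplace_back(kv_cum.back() + kv_len + offset);
      }
      for (size_t i = 0; i + 1 < kv_cum.size(); ++i) {
        std::cout << "row " << i << ": q_len=1, kv_len="
                  << kv_cum[i + 1] - kv_cum[i] << "\n";  // 10, 11, 12
      }
    }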

xllm/core/runtime/worker_impl.cpp

Lines changed: 17 additions & 0 deletions
@@ -607,9 +607,26 @@ bool WorkerImpl::init_model(const std::string& model_weights_path,
     }
   }
 
+#if defined(USE_NPU)
   if (options_.enable_speculative_decode() && FLAGS_enable_atb_spec_kernel) {
     args.num_speculative_tokens(options_.num_speculative_tokens());
   }
+#else
+  if (options_.enable_speculative_decode()) {
+    args.num_speculative_tokens(options_.num_speculative_tokens());
+    // When running speculative decoding, the draft worker reuses the same
+    // checkpoint as the target DeepSeek V3/V32 model. The draft worker needs
+    // to instantiate the MTP variant, so override the model_type here without
+    // mutating the original config.
+    if (options_.num_speculative_tokens() == 0 &&
+        (args.model_type() == "deepseek_v3" ||
+         args.model_type() == "deepseek_v32")) {
+      LOG(INFO) << "Overriding draft model_type from " << args.model_type()
+                << " to deepseek_mtp for speculative decoding";
+      args.model_type("deepseek_mtp");
+    }
+  }
+#endif
 
   // create model context
   dtype_ = dtype;
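
Note: the draft worker is constructed with num_speculative_tokens forced to 0 (see the SpeculativeWorkerImpl constructor above), so on non-NPU builds that value doubles as the "this is the draft model" signal that triggers the deepseek_v3/v32 → deepseek_mtp override. A standalone sketch of just that decision, with a plain struct standing in for ModelArgs:

    #include <cassert>
    #include <cstdint>
    #include <string>

    // Hypothetical stand-in for the relevant ModelArgs field.
    struct Args {
      std::string model_type;
    };

    // Mirrors the override rule added above: only the draft worker
    // (num_spec_tokens == 0) of a DeepSeek V3/V3.2 checkpoint is switched
    // to the MTP variant.
    void maybe_override_draft_model_type(Args& args, int32_t num_spec_tokens) {
      if (num_spec_tokens == 0 && (args.model_type == "deepseek_v3" ||
                                   args.model_type == "deepseek_v32")) {
        args.model_type = "deepseek_mtp";
      }
    }

    int main() {
      Args draft{"deepseek_v3"}, target{"deepseek_v3"};
      maybe_override_draft_model_type(draft, /*num_spec_tokens=*/0);
      maybe_override_draft_model_type(target, /*num_spec_tokens=*/2);
      assert(draft.model_type == "deepseek_mtp");
      assert(target.model_type == "deepseek_v3");
    }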

xllm/models/llm/llm_model_base.h

Lines changed: 3 additions & 2 deletions
@@ -393,8 +393,9 @@ class LlmForCausalLMImplBase : public torch::nn::Module {
 #endif
   }
 
-  void load_model(std::unique_ptr<ModelLoader> loader,
-                  std::string prefix = "model." /*llm model weight prefix*/) {
+  virtual void load_model(
+      std::unique_ptr<ModelLoader> loader,
+      std::string prefix = "model." /*llm model weight prefix*/) {
     for (const auto& state_dict : loader->get_state_dicts()) {
       model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
       if (tie_word_embeddings) {
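
Note: making load_model virtual lets an MTP model class reuse LlmForCausalLMImplBase while remapping the weight prefix when loading from the shared DeepSeek checkpoint. A minimal sketch of the override pattern; the DeepseekMtpForCausalLM name and the concrete prefix are illustrative assumptions, not the repo's exact code:

    #include <iostream>
    #include <memory>
    #include <string>

    // Simplified stand-in for the repo's loader type.
    struct ModelLoader {};

    class LlmBase {
     public:
      virtual ~LlmBase() = default;
      // Base behavior: load weights under the default "model." prefix.
      virtual void load_model(std::unique_ptr<ModelLoader> loader,
                              std::string prefix = "model.") {
        std::cout << "loading with prefix: " << prefix << "\n";
      }
    };

    class DeepseekMtpForCausalLM : public LlmBase {  // illustrative name
     public:
      // Override to pull the MTP layer's weights out of the shared
      // checkpoint; the concrete sub-prefix below is hypothetical.
      void load_model(std::unique_ptr<ModelLoader> loader,
                      std::string prefix = "model.") override {
        LlmBase::load_model(std::move(loader), prefix + "layers.61.");
      }
    };

    int main() {
      std::unique_ptr<LlmBase> model =
          std::make_unique<DeepseekMtpForCausalLM>();
      model->load_model(std::make_unique<ModelLoader>());
    }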
