@@ -58,6 +58,98 @@ int32_t kv_cache_slot_id(int32_t position,
5858 return block_id * block_size + block_offset;
5959}
6060
61+ // Convert tensor to int64 for MLU platform (temp workaround)
62+ // MLU will support int32 for masked_scatter in the future
63+ torch::Tensor ensure_int64_for_certain_platform (torch::Tensor tensor) {
64+ #if defined(USE_MLU)
65+ return tensor.to (torch::kInt64 );
66+ #else
67+ return tensor;
68+ #endif
69+ }
70+
71+ // Push cumulative sum to vector (used for cumulative format)
72+ void push_cumsum (std::vector<int32_t >& vec, int32_t len) {
73+ if (vec.empty ()) {
74+ vec.emplace_back (0 );
75+ }
76+ vec.emplace_back (vec.back () + len);
77+ }
78+
79+ // Calculate actual kv_len based on platform type
80+ // For NPU: direct format - returns kv_seq_lens_slice[seq_id] + offset
81+ // For MLU/CUDA: cumulative format - returns the actual length increment
82+ int32_t calculate_kv_len (const Slice<int32_t >& kv_seq_lens_slice,
83+ int32_t seq_id,
84+ int32_t offset) {
85+ #if defined(USE_NPU)
86+ return kv_seq_lens_slice[seq_id] + offset;
87+ #elif defined(USE_MLU) || defined(USE_CUDA)
88+ return kv_seq_lens_slice[seq_id + 1 ] - kv_seq_lens_slice[seq_id] + offset;
89+ #endif
90+ }
91+
92+ // Append sequence length to vector based on platform type
93+ // For NPU: directly add the len value
94+ // For MLU/CUDA: add using cumulative format
95+ void append_seq_len (std::vector<int32_t >& vec, int32_t len) {
96+ #if defined(USE_NPU)
97+ vec.emplace_back (len);
98+ #elif defined(USE_MLU) || defined(USE_CUDA)
99+ push_cumsum (vec, len);
100+ #endif
101+ }
102+
103+ // Update kv_seq_lens_vec and kv_max_seq_len
104+ void update_kv_seq_lens_and_max (std::vector<int32_t >& kv_seq_lens_vec,
105+ int32_t kv_len,
106+ int32_t & kv_max_seq_len) {
107+ // Update max (same logic for both platforms)
108+ if (kv_len > kv_max_seq_len) {
109+ kv_max_seq_len = kv_len;
110+ }
111+ // Update kv_seq_lens_vec
112+ append_seq_len (kv_seq_lens_vec, kv_len);
113+ }
114+
115+ // Batch expansion strategy for validation
116+ void batch_expansion_process_seq_lens (
117+ std::vector<int32_t >& kv_seq_lens_vec,
118+ std::vector<int32_t >& q_seq_lens_vec,
119+ std::vector<std::vector<int32_t >>& block_tables_vec,
120+ int32_t & kv_max_seq_len,
121+ const Slice<int32_t >& kv_seq_lens_slice,
122+ const Slice<int32_t >& block_table_slice,
123+ int32_t seq_id,
124+ int32_t position_offset,
125+ int32_t num_val_tokens) {
126+ for (int32_t offset = position_offset;
127+ offset < num_val_tokens + position_offset;
128+ ++offset) {
129+ // Calculate kv length and update kv_seq_lens_vec and kv_max_seq_len
130+ int32_t kv_len = calculate_kv_len (kv_seq_lens_slice, seq_id, offset);
131+ update_kv_seq_lens_and_max (kv_seq_lens_vec, kv_len, kv_max_seq_len);
132+ // Append sequence length of 1 to q_seq_lens_vec
133+ // for batch expansion strategy for validation
134+ append_seq_len (q_seq_lens_vec, 1 );
135+ // Append block table to block_tables_vec
136+ block_tables_vec.emplace_back (block_table_slice);
137+ }
138+ }
139+
140+ // Update kv_seq_lens_vec based on platform type
141+ // For NPU: directly add kv_seq_lens_slice[seq_id] + offset
142+ // For others: build cumulative format
143+ // Also updates kv_max_seq_len to track the maximum sequence length
144+ void update_kv_seq_lens_vec (std::vector<int32_t >& kv_seq_lens_vec,
145+ const Slice<int32_t >& kv_seq_lens_slice,
146+ int32_t seq_id,
147+ int32_t offset,
148+ int32_t & kv_max_seq_len) {
149+ int32_t kv_len = calculate_kv_len (kv_seq_lens_slice, seq_id, offset);
150+ update_kv_seq_lens_and_max (kv_seq_lens_vec, kv_len, kv_max_seq_len);
151+ }
152+
61153} // namespace
62154
63155SpeculativeWorkerImpl::SpeculativeWorkerImpl (const ParallelArgs& parallel_args,
@@ -68,6 +160,11 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
68160 runtime_options.enable_schedule_overlap (false );
69161 impl_ =
70162 std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
163+ // here we specify num speculative tokens to 0 to pass the indication of
164+ // draft model to worker when enable_speculative_decode.
165+ // NOTE: If you want to modify this part, make sure you also check the usage
166+ // of
167+ // num_speculative_tokens in draft model.
71168 runtime_options.num_decoding_tokens (1 ).num_speculative_tokens (0 );
72169 draft_impl_ =
73170 std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
@@ -196,13 +293,15 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
196293
197294 // prepare input for draft model
198295 auto & embeddings = output.sample_output .embeddings ;
199- auto next_tokens = safe_to (output.sample_output .next_tokens , torch::kInt );
296+ auto next_tokens = ensure_int64_for_certain_platform (
297+ safe_to (output.sample_output .next_tokens , torch::kInt ));
200298
201299 if (embeddings.defined ()) {
202300 prefill_input.input_params .input_embedding = embeddings.clone ();
203301 }
204302 if (next_tokens.defined ()) {
205303 auto & token_ids = prefill_input.token_ids ;
304+ token_ids = ensure_int64_for_certain_platform (token_ids);
206305 auto mask = (token_ids == -1 );
207306 token_ids.masked_scatter_ (mask, next_tokens);
208307 }
@@ -259,7 +358,7 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
259358 new_token_ids.reserve (input.token_ids .numel ());
260359 for (size_t i = 0 ; i < input_params.num_sequences ; ++i) {
261360 int32_t q_len = 0 ;
262- q_len = input_params.q_seq_lens_vec [i] ;
361+ q_len = input_params.get_q_seq_len (i) ;
263362 Slice<int32_t > tokens_ids_slice_i =
264363 tokens_ids_slice.slice (start_idx + 1 , start_idx + q_len);
265364 start_idx += q_len;
@@ -316,9 +415,10 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
316415
317416 for (int i = 0 ; i < options_.num_speculative_tokens (); ++i) {
318417 ForwardOutput draft_output = draft_outputs[i];
319- auto next_tokens =
320- safe_to (draft_output.sample_output .next_tokens , torch::kInt );
418+ auto next_tokens = ensure_int64_for_certain_platform (
419+ safe_to (draft_output.sample_output .next_tokens , torch::kInt )) ;
321420 auto & token_ids = validate_input.token_ids ;
421+ token_ids = ensure_int64_for_certain_platform (token_ids);
322422 auto mask = (token_ids == -1 * (i + 1 ));
323423 token_ids.masked_scatter_ (mask, next_tokens);
324424 }
@@ -381,9 +481,13 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input,
381481 Slice<int32_t > kv_seq_lens_slice = input_params.kv_seq_lens_vec ;
382482 torch::Tensor block_tables = safe_to (input_params.block_tables , torch::kCPU );
383483
484+ // Initialize kv_max_seq_len to 0
485+ int32_t kv_max_seq_len = 0 ;
486+
384487 for (int32_t seq_id = 0 ; seq_id < num_sequences; ++seq_id) {
385488 new_positions.emplace_back (positions_slice[seq_id] + offset);
386- kv_seq_lens_vec.emplace_back (kv_seq_lens_slice[seq_id] + offset);
489+ update_kv_seq_lens_vec (
490+ kv_seq_lens_vec, kv_seq_lens_slice, seq_id, offset, kv_max_seq_len);
387491 torch::Tensor block_table = block_tables[seq_id];
388492 Slice<int32_t > block_table_slice = {block_table.data_ptr <int32_t >(),
389493 block_table.numel ()};
@@ -394,7 +498,7 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(const ForwardInput& input,
394498
395499 draft_input.positions = torch::tensor (new_positions, int_options);
396500 // update the input_params
397- input_params.kv_max_seq_len = input_params. kv_max_seq_len + offset ;
501+ input_params.kv_max_seq_len = kv_max_seq_len;
398502 input_params.kv_seq_lens_vec = kv_seq_lens_vec;
399503 input_params.kv_seq_lens = torch::tensor (kv_seq_lens_vec, int_options);
400504 input_params.new_cache_slots = torch::tensor (new_token_slot_ids, int_options);
@@ -438,6 +542,7 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
438542 std::vector<int32_t > new_token_slot_ids;
439543 std::vector<std::vector<int32_t >> block_tables_vec;
440544
545+ int32_t kv_max_seq_len = 0 ;
441546 for (int32_t seq_id = 0 ; seq_id < num_sequences; ++seq_id) {
442547 new_token_ids.emplace_back (tokens_ids_slice[seq_id]);
443548 new_positions.emplace_back (positions_slice[seq_id] + position_offset);
@@ -453,17 +558,27 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
453558
454559 // process kv length and q length
455560 if (FLAGS_enable_atb_spec_kernel) {
561+ // expand the num of decode tokens for each batch in the batch for
562+ // validation
456563 kv_seq_lens_vec.emplace_back (kv_seq_lens_slice[seq_id] +
457564 num_speculative_tokens + position_offset);
458565 q_seq_lens_vec.emplace_back (num_val_tokens);
459- } else {
460- for (int32_t offset = position_offset;
461- offset < num_val_tokens + position_offset;
462- ++offset) {
463- q_seq_lens_vec.emplace_back (1 );
464- kv_seq_lens_vec.emplace_back (kv_seq_lens_slice[seq_id] + offset);
465- block_tables_vec.emplace_back (block_table_slice);
566+ // update max for NPU: direct format, compare with new value
567+ if (kv_seq_lens_vec.back () > kv_max_seq_len) {
568+ kv_max_seq_len = kv_seq_lens_vec.back ();
466569 }
570+ } else {
571+ // expand the batch sizes for validation
572+ // and update max for MLU/CUDA: cumulative format, compare with new value
573+ batch_expansion_process_seq_lens (kv_seq_lens_vec,
574+ q_seq_lens_vec,
575+ block_tables_vec,
576+ kv_max_seq_len,
577+ kv_seq_lens_slice,
578+ block_table_slice,
579+ seq_id,
580+ position_offset,
581+ num_val_tokens);
467582 }
468583
469584 // process slot id
@@ -490,8 +605,7 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
490605 input_params.q_seq_lens_vec = std::move (q_seq_lens_vec);
491606 input_params.q_seq_lens =
492607 torch::tensor (input_params.q_seq_lens_vec , int_options);
493- input_params.kv_max_seq_len =
494- *std::max_element (kv_seq_lens_vec.begin (), kv_seq_lens_vec.end ());
608+ input_params.kv_max_seq_len = kv_max_seq_len;
495609 input_params.kv_seq_lens_vec = std::move (kv_seq_lens_vec);
496610 input_params.kv_seq_lens =
497611 torch::tensor (input_params.kv_seq_lens_vec , int_options);
@@ -573,6 +687,7 @@ SampleOutput SpeculativeWorkerImpl::validate(
573687 size_t num_draft_tokens = num_target_tokens - batch_size;
574688 COUNTER_ADD (speculative_num_draft_tokens_total, num_draft_tokens);
575689 COUNTER_ADD (speculative_num_accepted_tokens_total, num_draft_tokens - count);
690+
576691 return sample_output;
577692}
578693
@@ -591,11 +706,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
591706 torch::Tensor positions = safe_to (inputs.positions , torch::kCPU );
592707 Slice<int32_t > positions_slice = {positions.data_ptr <int32_t >(),
593708 positions.numel ()};
709+ // Get the tokens generated in the last step (flattened for easier indexing)
594710 torch::Tensor last_token_ids = safe_to (
595711 last_step_output_.sample_output .next_tokens .flatten (), torch::kCPU );
596712 Slice<int64_t > last_tokens_ids_slice = {last_token_ids.data_ptr <int64_t >(),
597713 last_token_ids.numel ()};
598714
715+ // Determine how many tokens were decoded in the last step
716+ // If the output is 2D, it means multiple tokens were generated per sequence
599717 int32_t last_step_decode_num = 1 ;
600718 if (last_step_output_.sample_output .next_tokens .dim () == 2 ) {
601719 last_step_decode_num = last_step_output_.sample_output .next_tokens .size (1 );
@@ -613,13 +731,23 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
613731 kv_seq_lens_vec.reserve (num_sequences);
614732 new_token_slot_ids.reserve (num_sequences);
615733
616- // get right token id and position
734+ // Initialize kv_max_seq_len to 0
735+ int32_t kv_max_seq_len = 0 ;
736+
737+ // Process each sequence to get the correct token ID and position for the next
738+ // step
617739 for (int32_t seq_id = 0 ; seq_id < num_sequences; ++seq_id) {
618740 int32_t postion_offset = 0 ;
619741 int32_t last_step_token_id = 0 ;
742+
743+ // If the token ID is non-negative, it's a direct token ID (not a
744+ // placeholder)
620745 if (tokens_ids_slice[seq_id] >= 0 ) {
621746 last_step_token_id = tokens_ids_slice[seq_id];
622747 } else {
748+ // Negative token IDs are placeholders that need to be resolved from
749+ // last_step_output_ The absolute value minus 1 gives the index into the
750+ // last step's output
623751 int32_t last_step_index = -1 * tokens_ids_slice[seq_id] - 1 ;
624752 last_step_index = last_step_index * last_step_decode_num;
625753 postion_offset = -1 ;
@@ -634,8 +762,14 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
634762
635763 new_token_ids.emplace_back (last_step_token_id);
636764 new_positions.emplace_back (positions_slice[seq_id] + postion_offset);
637- kv_seq_lens_vec.emplace_back (kv_seq_lens_slice[seq_id] + postion_offset);
638-
765+ update_kv_seq_lens_vec (kv_seq_lens_vec,
766+ kv_seq_lens_slice,
767+ seq_id,
768+ postion_offset,
769+ kv_max_seq_len);
770+
771+ // Calculate the new cache slot ID based on the position offset
772+ // This handles cases where we need to move to a different block
639773 torch::Tensor block_table = block_tables[seq_id];
640774 Slice<int32_t > block_table_slice = {block_table.data_ptr <int32_t >(),
641775 block_table.numel ()};
@@ -644,12 +778,12 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output(
644778 new_token_slot_ids.emplace_back (slot_id);
645779 }
646780
781+ // Create new tensors with updated values
647782 torch::TensorOptions int_options = inputs.token_ids .options ();
648783 new_inputs.token_ids = torch::tensor (new_token_ids, int_options);
649784 new_inputs.positions = torch::tensor (new_positions, int_options);
650785 // update the input_params
651- input_params.kv_max_seq_len =
652- *std::max_element (kv_seq_lens_vec.begin (), kv_seq_lens_vec.end ());
786+ input_params.kv_max_seq_len = kv_max_seq_len;
653787 input_params.kv_seq_lens_vec = std::move (kv_seq_lens_vec);
654788 input_params.kv_seq_lens =
655789 torch::tensor (input_params.kv_seq_lens_vec , int_options);
0 commit comments