@@ -87,6 +87,13 @@ ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options)
8787 instance_info_.name = options_.instance_name ().value_or (" " );
8888 instance_info_.type = options_.instance_role ().value ().to_string ();
8989 instance_info_.dp_size = options.dp_size ();
90+
91+ if (options_.enable_schedule_overlap ()) {
92+ min_speculative_tokens_required_ = options_.num_speculative_tokens () * 2 ;
93+ } else {
94+ min_speculative_tokens_required_ = options_.num_speculative_tokens ();
95+ }
96+
9097}
9198
// Destructor: calls clear() on running_requests_ when the scheduler is torn down.
9299ContinuousScheduler::~ContinuousScheduler () { running_requests_.clear (); }
@@ -366,7 +373,7 @@ void ContinuousScheduler::handle_decode_requests(
366373 size_t & num_online_decode_preempt_offline_requests,
367374 std::unique_ptr<DecodePriorityQueue>& running_queue) {
368375 while (!running_queue->empty () &&
369- remaining_token_budget > options_. num_speculative_tokens () * 2 &&
376+ remaining_token_budget > min_speculative_tokens_required_ &&
370377 latency_budget > estimate_latency && remaining_seq_budget > 0 ) {
371378 std::shared_ptr<Request> request = running_queue->top ();
372379 // TODO: check if request is timeout
@@ -402,15 +409,15 @@ void ContinuousScheduler::handle_decode_requests(
402409 break ;
403410 }
404411 }
405- if (allocated_tokens + options_. num_speculative_tokens () * 2 >=
412+ if (allocated_tokens + min_speculative_tokens_required_ >=
406413 remaining_token_budget ||
407414 allocated_seqs >= remaining_seq_budget) {
408415 has_enough_budget = false ;
409416 break ;
410417 }
411418 // sequence token already appended
412419 size_t updated_num_tokens =
413- sequence->num_tokens () + options_. num_speculative_tokens () * 2 ;
420+ sequence->num_tokens () + min_speculative_tokens_required_ ;
414421 // no blocks left
415422 if (!kv_cache_manager_->allocate (sequence.get (), updated_num_tokens)) {
416423 has_enough_blocks = false ;
@@ -422,12 +429,12 @@ void ContinuousScheduler::handle_decode_requests(
422429 }
423430
424431 // update the allocated tokens for the sequence
425- allocated_tokens += options_. num_speculative_tokens () * 2 + 1 ;
432+ allocated_tokens += min_speculative_tokens_required_ + 1 ;
426433 allocated_seqs += 1 ;
427434 allocated_estimate_latency += seq_estimate_latency;
428435 candidate_sequences.emplace_back (sequence.get ());
429- candidate_token_budgets.emplace_back (
430- options_. num_speculative_tokens () * 2 + 1 );
436+ candidate_token_budgets.emplace_back (min_speculative_tokens_required_ +
437+ 1 );
431438 }
432439 CHECK (allocated_tokens <= remaining_token_budget);
433440 CHECK (allocated_seqs <= remaining_seq_budget);
0 commit comments