Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions vllm/v1/core/sched/interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
Expand All @@ -18,6 +19,20 @@
from vllm.v1.structured_output import StructuredOutputManager


class PauseState(enum.IntEnum):
    """Scheduler pause state.

    - UNPAUSED: Normal operation.
    - PAUSED_NEW: No new (waiting) requests are scheduled; requests already
      in the running state keep being scheduled.
    - PAUSED_ALL: No requests are scheduled at all.
    """

    UNPAUSED = 0
    PAUSED_NEW = 1
    PAUSED_ALL = 2


class SchedulerInterface(ABC):
@abstractmethod
def __init__(
Expand Down Expand Up @@ -120,9 +135,9 @@ def add_request(self, request: "Request") -> None:
@abstractmethod
def finish_requests(
self,
request_ids: str | Iterable[str],
request_ids: str | Iterable[str] | None,
finished_status: "RequestStatus",
) -> list[str]:
) -> list[tuple[str, int]]:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing for that request.

Expand All @@ -132,12 +147,12 @@ def finish_requests(
de-tokenizing its generated tokens.

Args:
request_ids: A single or a list of request IDs.
request_ids: A single or a list of request IDs, or None to finish all.
finished_status: The finished status of the given requests.

Returns:
List of request IDs that were actually finished (were in the
scheduler and not already finished).
List of (req_id, client_index) tuples for requests that were actually
finished. Will not include any that were already finished.
"""
raise NotImplementedError

Expand Down Expand Up @@ -172,16 +187,7 @@ def has_requests(self) -> bool:
return self.has_unfinished_requests() or self.has_finished_requests()

@abstractmethod
def get_all_request_ids(self) -> list[str]:
"""Return all request IDs currently in the scheduler (running or waiting)."""
raise NotImplementedError

def get_request_client_indices(
self, request_ids: "Iterable[str]"
) -> "dict[str, int]":
"""Return request_id -> client_index for requests that exist and are not
finished. Used to route abort outputs to the correct client(s).
"""
def set_pause_state(self, pause_state: PauseState) -> None:
raise NotImplementedError

@abstractmethod
Expand Down
62 changes: 33 additions & 29 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
Expand Down Expand Up @@ -271,6 +271,8 @@ def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
vllm_config=self.vllm_config,
)

self.pause_state: PauseState = PauseState.UNPAUSED

def _mamba_block_aligned_split(
self,
request: Request,
Expand Down Expand Up @@ -338,6 +340,10 @@ def schedule(self) -> SchedulerOutput:
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
if self.pause_state == PauseState.PAUSED_ALL:
# Do not schedule any requests when paused.
token_budget = 0

# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens
Expand Down Expand Up @@ -525,12 +531,12 @@ def schedule(self) -> SchedulerOutput:
)
assert len(scheduled_loras) <= self.lora_config.max_loras

# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)

# Next, schedule the WAITING requests.
if not preempted_reqs:
if not preempted_reqs and self.pause_state == PauseState.UNPAUSED:
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)

while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
Expand Down Expand Up @@ -797,9 +803,10 @@ def schedule(self) -> SchedulerOutput:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)

# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)

# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
Expand Down Expand Up @@ -1641,20 +1648,6 @@ def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)

def get_all_request_ids(self) -> list[str]:
"""Return all request IDs currently in the scheduler (running or waiting)."""
return list(self.requests.keys())

def get_request_client_indices(self, request_ids: Iterable[str]) -> dict[str, int]:
"""Return request_id -> client_index for requests that exist and are not
finished.
"""
return {
rid: self.requests[rid].client_index
for rid in request_ids
if rid in self.requests and not self.requests[rid].is_finished()
}

def add_request(self, request: Request) -> None:
existing = self.requests.get(request.request_id)
if existing is not None:
Expand All @@ -1678,22 +1671,26 @@ def add_request(self, request: Request) -> None:
request.record_event(EngineCoreEventType.QUEUED)

def finish_requests(
self, request_ids: str | Iterable[str], finished_status: RequestStatus
) -> list[str]:
self, request_ids: str | Iterable[str] | None, finished_status: RequestStatus
) -> list[tuple[str, int]]:
"""Handles the finish signal from outside the scheduler.

For example, the API server can abort a request when the client
disconnects.

If request_ids is None, all requests will be finished.

Returns:
List of request IDs that were actually finished (were in the
scheduler and not already finished).
List of (req_id, client_index) tuples for requests that were actually
finished. Will not include any that were already finished.
"""
assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str):
request_ids = (request_ids,)
else:
elif request_ids is not None:
request_ids = set(request_ids)
else:
request_ids = self.requests.keys()

running_requests_to_remove = set()
waiting_requests_to_remove = []
Expand Down Expand Up @@ -1733,7 +1730,7 @@ def finish_requests(
request.status = finished_status
self._free_request(request, delay_free_blocks=delay_free_blocks)

return [r.request_id for r in valid_requests]
return [(r.request_id, r.client_index) for r in valid_requests]

def _free_request(
self, request: Request, delay_free_blocks: bool = False
Expand All @@ -1758,7 +1755,14 @@ def _free_blocks(self, request: Request):
self.kv_cache_manager.free(request)
del self.requests[request.request_id]

def set_pause_state(self, pause_state: PauseState) -> None:
    """Set the scheduler's pause state.

    Takes effect on the next schedule() call, which reads this flag to
    decide whether waiting and/or running requests may be scheduled.
    """
    self.pause_state = pause_state

def get_num_unfinished_requests(self) -> int:
    """Return the number of unfinished requests, adjusted for pause state.

    Reports 0 when fully paused; only the running count when paused for
    new requests; otherwise running plus waiting requests, excluding those
    still waiting for streaming input.
    """
    if self.pause_state == PauseState.PAUSED_ALL:
        return 0
    num_running = len(self.running)
    if self.pause_state == PauseState.PAUSED_NEW:
        return num_running
    schedulable_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
    return num_running + schedulable_waiting

Expand Down
68 changes: 1 addition & 67 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@

import enum
import time
from collections.abc import Callable, Mapping
from concurrent.futures import Future
from collections.abc import Mapping
from typing import Any, Literal

import msgspec
import numpy as np
import torch

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
Expand All @@ -20,31 +18,12 @@
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult

logger = init_logger(__name__)

# Type for pause_generation mode parameter.
# - "abort": Abort all in-flight requests immediately (default).
# - "wait": Wait for in-flight requests to complete before pausing.
# - "keep": Freeze requests in queue; they resume on resume_generation().
PauseMode = Literal["abort", "wait", "keep"]


class PauseState(enum.IntEnum):
    """Engine scheduler pause state. All states besides UNPAUSED add
    new requests to a queue that is flushed on resume.

    - UNPAUSED: Normal operation; step runs, adds go to scheduler.
    - PAUSE_ABORT: Paused (no step); in-flight requests are aborted.
    - PAUSE_KEEP: Paused (no step); in-flight requests stay frozen in the
      queue and resume on resume_generation().
    - PAUSE_WAIT: Draining in-flight (step runs) until they complete.
    """

    UNPAUSED = 0
    PAUSE_ABORT = 1
    PAUSE_KEEP = 2
    PAUSE_WAIT = 3


# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
Expand Down Expand Up @@ -182,51 +161,6 @@ def finished(self) -> bool:
return self.finish_reason is not None


class UtilityFuture:
    """
    Standard type for deferred utility completion. Utilities return this;
    the engine calls register_done() with request context and step() each loop.
    When the future completes, the same done behavior runs (set output, put on queue).
    """

    def __init__(
        self,
        future: Future[Any],
        step_fn: Callable[[], None],
    ) -> None:
        """Wrap a future plus the per-loop callable that drives it.

        Args:
            future: Completes with the utility's result (or an exception).
            step_fn: Invoked once per engine loop via step(); presumably
                advances the work that eventually completes the future —
                TODO confirm against callers.
        """
        # Underlying future whose completion produces the utility result.
        self.future = future
        # Per-loop driver; only exposed through step().
        self._step_fn = step_fn

    def register_done(
        self,
        output_queue: Any,
        client_idx: int,
        call_id: int,
        method_name: str,
    ) -> None:
        """Register the standard done behavior with request context.

        When the future completes, a UtilityOutput for `call_id` is built
        (result on success, failure_message on error) and enqueued on
        `output_queue` addressed to `client_idx`.
        """
        future = self.future

        def _done(f: Future[Any]) -> None:
            output = UtilityOutput(call_id)
            try:
                output.result = UtilityResult(f.result())
            # BaseException so even cancellation/interrupt still produces a
            # failure message for the client rather than a lost reply.
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
            # Route the output back to the originating client.
            output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )

        future.add_done_callback(_done)

    def step(self) -> None:
        """Run one step; the callback may complete the future when ready."""
        self._step_fn()


class UtilityOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
Expand Down
5 changes: 3 additions & 2 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ async def abort(
self, request_id: str | Iterable[str], internal: bool = False
) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""

request_ids = (
(request_id,) if isinstance(request_id, str) else as_list(request_id)
)
Expand All @@ -729,8 +730,8 @@ async def pause_generation(
Pause generation to allow model weight updates.

All mode handling (abort / wait / keep) and cache clearing is done
in the engine. New generation/encoding requests are queued by the
engine while paused and flushed on resume.
in the engine. New generation/encoding requests will not be scheduled
until resume is called.

Args:
mode: How to handle in-flight requests:
Expand Down
Loading