diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 279fdcc00b3d..be33315a7552 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum from abc import ABC, abstractmethod from collections.abc import Iterable from typing import TYPE_CHECKING @@ -18,6 +19,20 @@ from vllm.v1.structured_output import StructuredOutputManager +class PauseState(enum.IntEnum): + """Scheduler pause state. + + - UNPAUSED: Normal operation + - PAUSED_NEW: No new requests are scheduled, requests already in + running state are scheduled. + - PAUSED_ALL: No requests are scheduled + """ + + UNPAUSED = 0 + PAUSED_NEW = 1 + PAUSED_ALL = 2 + + class SchedulerInterface(ABC): @abstractmethod def __init__( @@ -120,9 +135,9 @@ def add_request(self, request: "Request") -> None: @abstractmethod def finish_requests( self, - request_ids: str | Iterable[str], + request_ids: str | Iterable[str] | None, finished_status: "RequestStatus", - ) -> list[str]: + ) -> list[tuple[str, int]]: """Finish the requests in the scheduler's internal queue. If the request is not in the queue, this method will do nothing for that request. @@ -132,12 +147,12 @@ def finish_requests( de-tokenizing its generated tokens. Args: - request_ids: A single or a list of request IDs. + request_ids: A single or a list of request IDs, or None to finish all. finished_status: The finished status of the given requests. Returns: - List of request IDs that were actually finished (were in the - scheduler and not already finished). + List of (req_id, client_index) tuples for requests that were aborted. Will not + include any that were already finished. 
""" raise NotImplementedError @@ -172,16 +187,7 @@ def has_requests(self) -> bool: return self.has_unfinished_requests() or self.has_finished_requests() @abstractmethod - def get_all_request_ids(self) -> list[str]: - """Return all request IDs currently in the scheduler (running or waiting).""" - raise NotImplementedError - - def get_request_client_indices( - self, request_ids: "Iterable[str]" - ) -> "dict[str, int]": - """Return request_id -> client_index for requests that exist and are not - finished. Used to route abort outputs to the correct client(s). - """ + def set_pause_state(self, pause_state: PauseState) -> None: raise NotImplementedError @abstractmethod diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 95d6486f2f8e..64ade2cb4370 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -38,7 +38,7 @@ ) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector -from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.output import ( CachedRequestData, GrammarOutput, @@ -271,6 +271,8 @@ def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: vllm_config=self.vllm_config, ) + self.pause_state: PauseState = PauseState.UNPAUSED + def _mamba_block_aligned_split( self, request: Request, @@ -338,6 +340,10 @@ def schedule(self) -> SchedulerOutput: req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens + if self.pause_state == PauseState.PAUSED_ALL: + # Do not schedule any requests when paused. + token_budget = 0 + # Encoder-related. 
scheduled_encoder_inputs: dict[str, list[int]] = {} encoder_compute_budget = self.max_num_encoder_input_tokens @@ -525,12 +531,12 @@ def schedule(self) -> SchedulerOutput: ) assert len(scheduled_loras) <= self.lora_config.max_loras - # Use a temporary RequestQueue to collect requests that need to be - # skipped and put back at the head of the waiting queue later - skipped_waiting_requests = create_request_queue(self.policy) - # Next, schedule the WAITING requests. - if not preempted_reqs: + if not preempted_reqs and self.pause_state == PauseState.UNPAUSED: + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + while self.waiting and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break @@ -797,9 +803,10 @@ def schedule(self) -> SchedulerOutput: self.encoder_cache_manager.allocate(request, i) if self.ec_connector is not None: self.ec_connector.update_state_after_alloc(request, i) - # Put back any skipped requests at the head of the waiting queue - if skipped_waiting_requests: - self.waiting.prepend_requests(skipped_waiting_requests) + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) @@ -1641,20 +1648,6 @@ def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" return len(self.running), len(self.waiting) - def get_all_request_ids(self) -> list[str]: - """Return all request IDs currently in the scheduler (running or waiting).""" - return list(self.requests.keys()) - - def get_request_client_indices(self, request_ids: Iterable[str]) -> dict[str, int]: - """Return request_id -> client_index for requests that exist and are not - finished. 
- """ - return { - rid: self.requests[rid].client_index - for rid in request_ids - if rid in self.requests and not self.requests[rid].is_finished() - } - def add_request(self, request: Request) -> None: existing = self.requests.get(request.request_id) if existing is not None: @@ -1678,22 +1671,26 @@ def add_request(self, request: Request) -> None: request.record_event(EngineCoreEventType.QUEUED) def finish_requests( - self, request_ids: str | Iterable[str], finished_status: RequestStatus - ) -> list[str]: + self, request_ids: str | Iterable[str] | None, finished_status: RequestStatus + ) -> list[tuple[str, int]]: """Handles the finish signal from outside the scheduler. For example, the API server can abort a request when the client disconnects. + If request_ids is None, all requests will be finished. + Returns: - List of request IDs that were actually finished (were in the - scheduler and not already finished). + List of (req_id, client_index) tuples for requests that were aborted. Will not + include any that were already finished. 
""" assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): request_ids = (request_ids,) - else: + elif request_ids is not None: request_ids = set(request_ids) + else: + request_ids = self.requests.keys() running_requests_to_remove = set() waiting_requests_to_remove = [] @@ -1733,7 +1730,7 @@ def finish_requests( request.status = finished_status self._free_request(request, delay_free_blocks=delay_free_blocks) - return [r.request_id for r in valid_requests] + return [(r.request_id, r.client_index) for r in valid_requests] def _free_request( self, request: Request, delay_free_blocks: bool = False @@ -1758,7 +1755,14 @@ def _free_blocks(self, request: Request): self.kv_cache_manager.free(request) del self.requests[request.request_id] + def set_pause_state(self, pause_state: PauseState) -> None: + self.pause_state = pause_state + def get_num_unfinished_requests(self) -> int: + if self.pause_state == PauseState.PAUSED_ALL: + return 0 + if self.pause_state == PauseState.PAUSED_NEW: + return len(self.running) num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input return num_waiting + len(self.running) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index ce2e2c5b1d55..d0b0370fb389 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -3,15 +3,13 @@ import enum import time -from collections.abc import Callable, Mapping -from concurrent.futures import Future +from collections.abc import Mapping from typing import Any, Literal import msgspec import numpy as np import torch -from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams @@ -20,31 +18,12 @@ from vllm.v1.outputs import LogprobsLists, LogprobsTensors from vllm.v1.serial_utils import UtilityResult -logger = init_logger(__name__) - # Type for pause_generation mode parameter. 
# - "abort": Abort all in-flight requests immediately (default). # - "wait": Wait for in-flight requests to complete before pausing. # - "keep": Freeze requests in queue; they resume on resume_generation(). PauseMode = Literal["abort", "wait", "keep"] - -class PauseState(enum.IntEnum): - """Engine scheduler pause state. All states besides UNPAUSED add - new requests to a queue that is flushed on resume. - - - UNPAUSED: Normal operation; step runs, adds go to scheduler. - - PAUSE_ABORT: Paused (no step) - - PAUSE_KEEP: Paused (no step) - - PAUSE_WAIT: Draining in-flight (step runs) - """ - - UNPAUSED = 0 - PAUSE_ABORT = 1 - PAUSE_KEEP = 2 - PAUSE_WAIT = 3 - - # These are possible values of RequestOutput.finish_reason, # so form part of the external API. FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") @@ -182,51 +161,6 @@ def finished(self) -> bool: return self.finish_reason is not None -class UtilityFuture: - """ - Standard type for deferred utility completion. Utilities return this; - the engine calls register_done() with request context and step() each loop. - When the future completes, the same done behavior runs (set output, put on queue). 
- """ - - def __init__( - self, - future: Future[Any], - step_fn: Callable[[], None], - ) -> None: - self.future = future - self._step_fn = step_fn - - def register_done( - self, - output_queue: Any, - client_idx: int, - call_id: int, - method_name: str, - ) -> None: - """Register the standard done behavior with request context.""" - future = self.future - - def _done(f: Future[Any]) -> None: - output = UtilityOutput(call_id) - try: - output.result = UtilityResult(f.result()) - except BaseException as e: - logger.exception("Invocation of %s method failed", method_name) - output.failure_message = ( - f"Call to {method_name} method failed: {str(e)}" - ) - output_queue.put_nowait( - (client_idx, EngineCoreOutputs(utility_output=output)) - ) - - future.add_done_callback(_done) - - def step(self) -> None: - """Run one step; the callback may complete the future when ready.""" - self._step_fn() - - class UtilityOutput( msgspec.Struct, array_like=True, # type: ignore[call-arg] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index c88713974394..f6fdb69eeaec 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -709,6 +709,7 @@ async def abort( self, request_id: str | Iterable[str], internal: bool = False ) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" + request_ids = ( (request_id,) if isinstance(request_id, str) else as_list(request_id) ) @@ -729,8 +730,8 @@ async def pause_generation( Pause generation to allow model weight updates. All mode handling (abort / wait / keep) and cache clearing is done - in the engine. New generation/encoding requests are queued by the - engine while paused and flushed on resume. + in the engine. New generation/encoding requests will not be scheduled + until resume is called. 
Args: mode: How to handle in-flight requests: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 7999f61ebb5e..799a45c18e02 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -5,7 +5,7 @@ import signal import threading import time -from collections import deque +from collections import defaultdict, deque from collections.abc import Callable, Generator from concurrent.futures import Future from contextlib import ExitStack, contextmanager @@ -40,7 +40,7 @@ get_request_block_hasher, init_none_hash, ) -from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ( EngineCoreOutput, @@ -49,10 +49,8 @@ EngineCoreRequestType, FinishReason, PauseMode, - PauseState, ReconfigureDistributedRequest, ReconfigureRankType, - UtilityFuture, UtilityOutput, UtilityResult, ) @@ -213,11 +211,7 @@ def __init__( self.aborts_queue = queue.Queue[list[str]]() - # Pause state; all non-UNPAUSED states queue new adds in _paused_adds_queue. - self._scheduler_pause_state = PauseState.UNPAUSED - # Requests received while paused; flushed on resume_scheduler(). - # Contains tuples of (request, request_wave) - self._paused_adds_queue: deque[tuple[Request, int]] = deque() + self.per_step_hooks: set[Callable] = set() # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. @@ -297,9 +291,6 @@ def add_request(self, request: Request, request_wave: int = 0): `request_wave`: indicate which wave of requests this is expected to belong to in DP case """ - if self._scheduler_pause_state != PauseState.UNPAUSED: - self._paused_adds_queue.append((request, request_wave)) - return # Validate the request_id type. 
if not isinstance(request.request_id, str): raise TypeError( @@ -327,145 +318,13 @@ def add_request(self, request: Request, request_wave: int = 0): self.scheduler.add_request(request) - def abort_requests(self, request_ids: list[str]) -> list[str]: - """Abort requests from the scheduler. - - When running in a process with an output_queue (e.g. EngineCoreProc), - also emits abort outputs so each client receives the abort for its - request(s). Outputs are routed by client_index so the waiting client - gets the finish. - - Requests in _paused_adds_queue whose request_id is in request_ids are - added to the scheduler (and removed from the queue) so they can be - aborted and the client notified. - - Returns: - List of request IDs that were actually aborted (were in the - scheduler). - """ - request_ids_set = set(request_ids) - # Add any paused-adds that should be aborted into the scheduler, then - # remove them from the queue so they are found by finish_requests below. - if self._paused_adds_queue: - new_queue: deque[tuple[Request, int]] = deque() - while self._paused_adds_queue: - request, request_wave = self._paused_adds_queue.popleft() - if request.request_id in request_ids_set: - self.scheduler.add_request(request) - else: - new_queue.append((request, request_wave)) - self._paused_adds_queue = new_queue + def abort_requests(self, request_ids: list[str]): + """Abort requests from the scheduler.""" - # Get client_index for each request before finish_requests removes them. - client_indices = self.scheduler.get_request_client_indices(request_ids) - aborted_ids = self.scheduler.finish_requests( - request_ids, RequestStatus.FINISHED_ABORTED - ) - output_queue = getattr(self, "output_queue", None) - if aborted_ids and output_queue is not None: - # Map client_index to list of request_ids that belong to that client. 
- by_client: dict[int, list[str]] = {} - for rid in aborted_ids: - client_idx = client_indices[rid] - by_client.setdefault(client_idx, []).append(rid) - for client_index, rids in by_client.items(): - output_queue.put_nowait( - ( - client_index, - EngineCoreOutputs( - finished_requests=set(rids), - outputs=[ - EngineCoreOutput( - request_id=rid, - new_token_ids=[], - finish_reason=FinishReason.ABORT, - ) - for rid in rids - ], - ), - ), - ) - return aborted_ids - - def pause_scheduler( - self, - mode: PauseMode = "abort", - clear_cache: bool = True, - ) -> UtilityFuture: - """Pause generation; behavior depends on mode. - - All pause states queue new adds. PAUSE_ABORT and PAUSE_KEEP skip step(); - PAUSE_WAIT allows step() so in-flight requests can drain. - - - ``abort``: Set PAUSE_ABORT, abort all requests, wait for abort - outputs to be sent (when running with output_queue), clear caches, - then complete the returned Future. - - ``wait``: Set PAUSE_WAIT (queue adds, keep stepping); when drained, - set PAUSE_KEEP, clear caches, complete the returned Future. - - ``keep``: Set PAUSE_KEEP; return a Future that completes when the - output queue is empty. 
- """ - - if not hasattr(self, "_pending_step_completions"): - raise RuntimeError( - "pause_scheduler with deferrable modes requires " - "_pending_step_completions (use EngineCoreProc, not EngineCore)" - ) - if not hasattr(self, "output_queue"): - raise RuntimeError( - "pause_scheduler requires output_queue (use EngineCoreProc)" - ) - else: - output_queue = self.output_queue - - future: Future[Any] = Future() - - def _step_queue_empty() -> None: - if not output_queue.empty(): - return - future.set_result(None) - - if mode == "abort": - self._scheduler_pause_state = PauseState.PAUSE_ABORT - request_ids = self.scheduler.get_all_request_ids() - if request_ids: - self.abort_requests(request_ids) - if clear_cache: - self.reset_prefix_cache() - self.reset_mm_cache() - self.reset_encoder_cache() - return UtilityFuture(future, _step_queue_empty) - elif mode == "keep": - self._scheduler_pause_state = PauseState.PAUSE_KEEP - return UtilityFuture(future, _step_queue_empty) - - elif mode == "wait": - # wait: PAUSE_WAIT so adds are queued but step() still runs to drain. 
- self._scheduler_pause_state = PauseState.PAUSE_WAIT - - def _step_wait() -> None: - if self.scheduler.has_unfinished_requests(): - return - self._scheduler_pause_state = PauseState.PAUSE_KEEP - if clear_cache: - self.reset_prefix_cache() - self.reset_mm_cache() - self.reset_encoder_cache() - future.set_result(None) - - return UtilityFuture(future, _step_wait) - else: - raise ValueError(f"Invalid pause mode: {mode}") - - def resume_scheduler(self) -> None: - """Resume the scheduler and flush any requests queued while paused.""" - self._scheduler_pause_state = PauseState.UNPAUSED - while self._paused_adds_queue: - self.add_request(*self._paused_adds_queue.popleft()) - - def is_scheduler_paused(self) -> bool: - """Return whether the scheduler is in any pause state.""" - return self._scheduler_pause_state != PauseState.UNPAUSED + # TODO: The scheduler doesn't really need to know the + # specific finish reason, TBD whether we propagate that + # (i.e. client-aborted vs stop criteria met). + self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) @contextmanager def log_error_detail(self, scheduler_output: SchedulerOutput): @@ -520,13 +379,6 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: was executed. """ - # If paused (abort/keep), don't schedule any work. PAUSE_WAIT allows step. - if self._scheduler_pause_state in ( - PauseState.PAUSE_ABORT, - PauseState.PAUSE_KEEP, - ): - return {}, False - # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): @@ -577,12 +429,6 @@ def step_with_batch_queue( batch in the job queue is finished. 3. Update the scheduler from the output. """ - # If paused (abort/keep), don't schedule any work. PAUSE_WAIT allows step. 
- if self._scheduler_pause_state in ( - PauseState.PAUSE_ABORT, - PauseState.PAUSE_KEEP, - ): - return {}, False batch_queue = self.batch_queue assert batch_queue is not None @@ -833,8 +679,6 @@ def __init__( ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]() - # Deferred utilities: engine calls .step() each loop until .future is done. - self._pending_step_completions: list[UtilityFuture] = [] executor_fail_callback = lambda: self.input_queue.put_nowait( (EngineCoreRequestType.EXECUTOR_FAILED, b"") ) @@ -1159,8 +1003,8 @@ def run_busy_loop(self): self._process_input_queue() # 2) Step the engine core and return the outputs. self._process_engine_step() - # 3) Run step-based completion checks (e.g. pause_scheduler futures). - self._process_pending_step_completions() + # 3) Run per-step hooks. + self._process_per_step_hooks() def _process_input_queue(self): """Exits when an engine step needs to be performed.""" @@ -1170,8 +1014,6 @@ def _process_input_queue(self): not self.engines_running and not self.scheduler.has_requests() and not self.batch_queue - and self._scheduler_pause_state == PauseState.UNPAUSED - and not self._pending_step_completions ): if self.input_queue.empty(): # Drain aborts queue; all aborts are also processed via input_queue. 
@@ -1211,24 +1053,12 @@ def _process_engine_step(self) -> bool: return model_executed - def _process_pending_step_completions(self) -> None: - """Run step() on deferred utilities; remove any whose future is done.""" - pending = self._pending_step_completions - if not pending: - return - still_pending: list[UtilityFuture] = [] - for uf in pending: - if uf.future.done(): - continue - try: - uf.step() - except BaseException as e: - logger.exception("Step completion check failed") - uf.future.set_exception(e) - continue - if not uf.future.done(): - still_pending.append(uf) - self._pending_step_completions = still_pending + def _process_per_step_hooks(self) -> None: + if self.per_step_hooks: + for hook in list(self.per_step_hooks): + finished = hook() + if finished: + self.per_step_hooks.discard(hook) def _handle_client_request( self, request_type: EngineCoreRequestType, request: Any @@ -1242,30 +1072,8 @@ def _handle_client_request( self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: client_idx, call_id, method_name, args = request - try: - method = getattr(self, method_name) - result = method(*self._convert_msgspec_args(method, args)) - - if isinstance(result, UtilityFuture): - result.register_done( - self.output_queue, client_idx, call_id, method_name - ) - self._pending_step_completions.append(result) - else: - output = UtilityOutput(call_id) - output.result = UtilityResult(result) - self.output_queue.put_nowait( - (client_idx, EngineCoreOutputs(utility_output=output)) - ) - except BaseException as e: - logger.exception("Invocation of %s method failed", method_name) - output = UtilityOutput(call_id) - output.failure_message = ( - f"Call to {method_name} method failed: {str(e)}" - ) - self.output_queue.put_nowait( - (client_idx, EngineCoreOutputs(utility_output=output)) - ) + output = UtilityOutput(call_id) + self._invoke_utility_method(method_name, None, client_idx, output, args) elif request_type == 
EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: @@ -1273,6 +1081,33 @@ "Unrecognized input request type encountered: %s", request_type ) + def _invoke_utility_method( + self, + method_name: str, + method: Callable | None, + client_idx: int, + output: UtilityOutput, + args=(), + ): + try: + if method is None: + method = getattr(self, method_name) + result = method(*self._convert_msgspec_args(method, args)) + if isinstance(result, Future): + result.add_done_callback( + lambda f: self._invoke_utility_method( + method_name, f.result, client_idx, output + ) + ) + return + output.result = UtilityResult(result) + except Exception as e: + logger.exception("Invocation of %s method failed", method_name) + output.failure_message = f"Call to {method_name} method failed: {str(e)}" + self.output_queue.put_nowait( + (client_idx, EngineCoreOutputs(utility_output=output)) + ) + @staticmethod def _convert_msgspec_args(method, args): """If a provided arg type doesn't match corresponding target method @@ -1479,6 +1314,78 @@ def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None: ) ) + def pause_scheduler( + self, mode: PauseMode = "abort", clear_cache: bool = True + ) -> Future | None: + """Pause generation; behavior depends on mode. + + Modes map onto the scheduler's PauseState: ``abort``/``wait`` use + PAUSED_NEW (stop scheduling new requests), ``keep`` uses PAUSED_ALL. + + - ``abort``: Abort all in-flight requests (their clients are notified); + returns a Future that completes once the scheduler has drained, + clearing caches first when clear_cache is True (None if already drained). + - ``wait``: Schedule no new requests and let in-flight requests run to + completion; returns a Future/None as for ``abort``. + - ``keep``: Schedule nothing; requests stay queued until + resume_scheduler() is called. Returns None immediately. 
"""         + if mode not in ("keep", "abort", "wait"): + raise ValueError(f"Invalid pause mode: {mode}") + + if mode == "keep": + # TODO could integrate with forced-preemption here for cache reset case + self.scheduler.set_pause_state(PauseState.PAUSED_ALL) + return None + + future: Future[Any] = Future() + + def _wait_for_running_to_finish() -> bool: + if not self.scheduler.has_requests(): + if clear_cache: + self.reset_prefix_cache() + self.reset_mm_cache() + self.reset_encoder_cache() + future.set_result(None) + return True + return False + + if mode == "abort": + aborted = self.scheduler.finish_requests( + None, RequestStatus.FINISHED_ABORTED + ) + if aborted: + # Map client_index to the set of request_ids that belong to that client. + by_client = defaultdict[int, set[str]](set) + for req_id, client_index in aborted: + by_client[client_index].add(req_id) + for client_index, req_ids in by_client.items(): + outputs = [ + EngineCoreOutput( + request_id=rid, + new_token_ids=[], + finish_reason=FinishReason.ABORT, + ) + for rid in req_ids + ] + eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs) + self.output_queue.put_nowait((client_index, eco)) + + self.scheduler.set_pause_state(PauseState.PAUSED_NEW) + if not _wait_for_running_to_finish(): + self.per_step_hooks.add(_wait_for_running_to_finish) + return future + return None + + def resume_scheduler(self) -> None: + """Resume the scheduler; previously paused requests become schedulable.""" + self.scheduler.set_pause_state(PauseState.UNPAUSED) + + def is_scheduler_paused(self) -> bool: + return False # TODO + """Return whether the scheduler is in any pause state.""" + # TODO: should report self.scheduler.pause_state != PauseState.UNPAUSED + class DPEngineCoreProc(EngineCoreProc): """ZMQ-wrapper for running EngineCore in background process diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 4bcd5bc88973..c33fc0083f1d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py 
@@ -978,9 +978,7 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: await self._send_input(EngineCoreRequestType.ABORT, request_ids) async def pause_scheduler_async( - self, - mode: PauseMode = "abort", - clear_cache: bool = True, + self, mode: PauseMode = "abort", clear_cache: bool = True ) -> None: await self.call_utility_async("pause_scheduler", mode, clear_cache)