From e4c50d90d0b46bedc117d2d8e3d5d5a23d6d87b3 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 9 Dec 2025 15:59:23 +0800
Subject: [PATCH 1/6] offline part

Signed-off-by: wang.yuqi
---
 vllm/config/vllm.py             |  7 +++++++
 vllm/entrypoints/llm.py         |  5 ++++-
 vllm/v1/core/sched/interface.py |  3 +++
 vllm/v1/core/sched/scheduler.py |  4 ++++
 vllm/v1/engine/__init__.py      |  5 +++++
 vllm/v1/engine/core.py          | 29 +++++++++++++++++++++++++++++
 vllm/v1/engine/core_client.py   | 10 ++++++++++
 7 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 36e4bd159dc7..5db79a7615b9 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -969,6 +969,13 @@ def has_blocked_weights():
         # Handle the KV connector configs
         self._post_init_kv_transfer_config()
 
+        from vllm.v1.engine import SchedulerReconfigure
+
+        self.org_scheduler_config = SchedulerReconfigure(
+            max_num_seqs=self.scheduler_config.max_num_seqs,
+            max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
+        )
+
     def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
         # enable sequence parallelism
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 913324fd5f9c..4bc9445ed825 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -76,7 +76,7 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils.collection_utils import as_iter, is_list_of
 from vllm.utils.counter import Counter
-from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine import EngineCoreRequest, SchedulerReconfigure
 from vllm.v1.engine.llm_engine import LLMEngine
 from vllm.v1.sample.logits_processor import LogitsProcessor
@@ -364,6 +364,9 @@ def reset_mm_cache(self) -> None:
         self.input_processor.clear_mm_cache()
         self.llm_engine.reset_mm_cache()
 
+    def reconfigure_scheduler(self, config: SchedulerReconfigure):
+        return self.llm_engine.engine_core.reconfigure_scheduler(config)
+
     def get_default_sampling_params(self) -> SamplingParams:
         if self.default_sampling_params is None:
             self.default_sampling_params = self.model_config.get_diff_sampling_param()
diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py
index 596ab05ad320..ed4a801998d9 100644
--- a/vllm/v1/core/sched/interface.py
+++ b/vllm/v1/core/sched/interface.py
@@ -187,3 +187,6 @@ def shutdown(self) -> None:
 
     def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
         return None
+
+    def reconfigure(self, max_num_seqs: int | None, max_num_batched_tokens: int | None):
+        raise NotImplementedError
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d858e840039c..9fb9755810d1 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1749,3 +1749,7 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
         # Return the IDs of affected running requests to skip in
         # update_from_output.
         return sync_affected_req_ids
+
+    def reconfigure(self, max_num_seqs: int, max_num_batched_tokens: int):
+        self.max_num_running_reqs = max_num_seqs
+        self.max_num_scheduled_tokens = max_num_batched_tokens
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index ce2aae77108d..2fd4c61bd2cd 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -212,3 +212,8 @@ class ReconfigureRankType(enum.IntEnum):
 
     KEEP_CURRENT_RANK = -1
     SHUTDOWN_CURRENT_RANK = -2
+
+
+class SchedulerReconfigure(msgspec.Struct, array_like=True):
+    max_num_seqs: int | None = None
+    max_num_batched_tokens: int | None = None
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 3d3a1e138dde..5b3132a1564b 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -48,6 +48,7 @@
     EngineCoreRequestType,
     ReconfigureDistributedRequest,
     ReconfigureRankType,
+    SchedulerReconfigure,
     UtilityOutput,
     UtilityResult,
 )
@@ -577,6 +578,34 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
             self.structured_output_manager.grammar_init(req)
         return req, request.current_wave
 
+    def reconfigure_scheduler(self, config: SchedulerReconfigure):
+        vllm_config = self.vllm_config
+        org_scheduler_config = vllm_config.org_scheduler_config
+
+        max_num_seqs = config.max_num_seqs
+        if max_num_seqs is None:
+            max_num_seqs = org_scheduler_config.max_num_seqs
+
+        max_num_batched_tokens = config.max_num_batched_tokens
+        if max_num_batched_tokens is None:
+            max_num_batched_tokens = org_scheduler_config.max_num_seqs
+
+        if max_num_seqs > org_scheduler_config.max_num_seqs:
+            return False
+
+        if max_num_batched_tokens > org_scheduler_config.max_num_batched_tokens:
+            return False
+
+        scheduler_config = self.vllm_config.scheduler_config
+        scheduler_config.max_num_seqs = config.max_num_seqs
+        scheduler_config.max_num_batched_tokens = config.max_num_batched_tokens
+
+        self.scheduler.reconfigure(
+            max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens
+        )
+
+        return True
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index c936646aa799..0a1e9b829520 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -36,6 +36,7 @@
     EngineCoreRequestType,
     ReconfigureDistributedRequest,
     ReconfigureRankType,
+    SchedulerReconfigure,
     UtilityOutput,
 )
 from vllm.v1.engine.coordinator import DPCoordinator
@@ -178,6 +179,9 @@ def save_sharded_state(
     ) -> None:
         raise NotImplementedError
 
+    def reconfigure_scheduler(self, config: SchedulerReconfigure):
+        raise NotImplementedError
+
     def collective_rpc(
         self,
         method: str | Callable[..., _R],
@@ -339,6 +343,9 @@ def collective_rpc(
     def dp_engines_running(self) -> bool:
         return False
 
+    def reconfigure_scheduler(self, config: SchedulerReconfigure):
+        self.engine_core.reconfigure_scheduler(config)
+
 
 @dataclass
 class BackgroundResources:
@@ -804,6 +811,9 @@ def save_sharded_state(
     ) -> None:
         self.call_utility("save_sharded_state", path, pattern, max_size)
 
+    def reconfigure_scheduler(self, config: SchedulerReconfigure):
+        self.call_utility("reconfigure_scheduler", config)
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""

From 8432f093ac0064832c9aa748b15a4a48923b12d7 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 9 Dec 2025 16:25:49 +0800
Subject: [PATCH 2/6] Update vllm/v1/engine/core.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: wang.yuqi
---
 vllm/v1/engine/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 5b3132a1564b..d5b5f2cdffa2 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -588,7 +588,7 @@ def reconfigure_scheduler(self, config: SchedulerReconfigure):
 
         max_num_batched_tokens = config.max_num_batched_tokens
         if max_num_batched_tokens is None:
-            max_num_batched_tokens = org_scheduler_config.max_num_seqs
+            max_num_batched_tokens = org_scheduler_config.max_num_batched_tokens
 
         if max_num_seqs > org_scheduler_config.max_num_seqs:
             return False

From fb3019fb10eacee5d26cf2ae3c0e65e960438be8 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 9 Dec 2025 16:40:40 +0800
Subject: [PATCH 3/6] + comment

Signed-off-by: wang.yuqi
---
 vllm/v1/engine/core.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index d5b5f2cdffa2..e02e0998dd4f 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -582,17 +582,15 @@ def reconfigure_scheduler(self, config: SchedulerReconfigure):
         vllm_config = self.vllm_config
         org_scheduler_config = vllm_config.org_scheduler_config
 
-        max_num_seqs = config.max_num_seqs
-        if max_num_seqs is None:
-            max_num_seqs = org_scheduler_config.max_num_seqs
-
-        max_num_batched_tokens = config.max_num_batched_tokens
-        if max_num_batched_tokens is None:
-            max_num_batched_tokens = org_scheduler_config.max_num_batched_tokens
+        max_num_seqs = config.max_num_seqs or org_scheduler_config.max_num_seqs
+        max_num_batched_tokens = (
+            config.max_num_batched_tokens or org_scheduler_config.max_num_batched_tokens
+        )
+        # The reconfigured values can only be less than or equal to their original
+        # values. Otherwise, it may lead to an OOM, or CUDA graphs are not covered.
         if max_num_seqs > org_scheduler_config.max_num_seqs:
             return False
-
         if max_num_batched_tokens > org_scheduler_config.max_num_batched_tokens:
             return False

From 4be915414510c8367131df6b4d95b29e626d7b40 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 11 Dec 2025 11:34:16 +0800
Subject: [PATCH 4/6] - SchedulerReconfigure

Signed-off-by: wang.yuqi
---
 vllm/config/vllm.py             |  9 +++------
 vllm/entrypoints/llm.py         | 10 +++++++---
 vllm/v1/core/sched/interface.py |  2 +-
 vllm/v1/core/sched/scheduler.py | 11 ++++-------
 vllm/v1/engine/__init__.py      |  5 -----
 vllm/v1/engine/core.py          | 19 +++++++++----------
 vllm/v1/engine/core_client.py   | 17 +++++++++++------
 7 files changed, 35 insertions(+), 38 deletions(-)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 296761e8ac83..64c4117dc16c 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -982,12 +982,9 @@ def has_blocked_weights():
         # Handle the KV connector configs
         self._post_init_kv_transfer_config()
 
-        from vllm.v1.engine import SchedulerReconfigure
-
-        self.org_scheduler_config = SchedulerReconfigure(
-            max_num_seqs=self.scheduler_config.max_num_seqs,
-            max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+        # for online reconfigure
+        self.org_max_num_seqs = self.scheduler_config.max_num_seqs
+        self.org_max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
 
     def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index fddf98444da0..c2fe39890f4a 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -76,7 +76,7 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils.collection_utils import as_iter, is_list_of
 from vllm.utils.counter import Counter
-from vllm.v1.engine import EngineCoreRequest, SchedulerReconfigure
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.llm_engine import LLMEngine
 from vllm.v1.sample.logits_processor import LogitsProcessor
@@ -370,8 +370,12 @@ def reset_mm_cache(self) -> None:
         self.input_processor.clear_mm_cache()
         self.llm_engine.reset_mm_cache()
 
-    def reconfigure_scheduler(self, config: SchedulerReconfigure):
-        return self.llm_engine.engine_core.reconfigure_scheduler(config)
+    def reconfigure(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
+        return self.llm_engine.engine_core.reconfigure(
+            max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens
+        )
 
     def get_default_sampling_params(self) -> SamplingParams:
         if self.default_sampling_params is None:
diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py
index ed4a801998d9..a78fffe9c70b 100644
--- a/vllm/v1/core/sched/interface.py
+++ b/vllm/v1/core/sched/interface.py
@@ -188,5 +188,5 @@ def shutdown(self) -> None:
     def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
         return None
 
-    def reconfigure(self, max_num_seqs: int | None, max_num_batched_tokens: int | None):
+    def reconfigure(self, max_num_seqs: int, max_num_batched_tokens: int) -> None:
         raise NotImplementedError
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 1fd9e941476a..960e4d8831a5 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1799,14 +1799,11 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
             total_failed_tokens,
         )
 
-        # Return the IDs of affected running requests to skip in
-        # update_from_output.
-        return sync_affected_req_ids
-
-    def reconfigure(self, max_num_seqs: int, max_num_batched_tokens: int):
-        self.max_num_running_reqs = max_num_seqs
-        self.max_num_scheduled_tokens = max_num_batched_tokens
         # Mark async requests with KV load failures for retry once loading completes
         self.failed_recving_kv_req_ids |= async_failed_req_ids
         # Return sync affected IDs to skip in update_from_output
         return sync_failed_req_ids
+
+    def reconfigure(self, max_num_seqs: int, max_num_batched_tokens: int) -> None:
+        self.max_num_running_reqs = max_num_seqs
+        self.max_num_scheduled_tokens = max_num_batched_tokens
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 7486672c6171..4f54d12f4b8d 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -215,8 +215,3 @@ class ReconfigureRankType(enum.IntEnum):
 
     KEEP_CURRENT_RANK = -1
     SHUTDOWN_CURRENT_RANK = -2
-
-
-class SchedulerReconfigure(msgspec.Struct, array_like=True):
-    max_num_seqs: int | None = None
-    max_num_batched_tokens: int | None = None
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 3b955b7198c5..80a54dd9bebb 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -48,7 +48,6 @@
     EngineCoreRequestType,
     ReconfigureDistributedRequest,
     ReconfigureRankType,
-    SchedulerReconfigure,
     UtilityOutput,
     UtilityResult,
 )
@@ -581,25 +580,25 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
             self.structured_output_manager.grammar_init(req)
         return req, request.current_wave
 
-    def reconfigure_scheduler(self, config: SchedulerReconfigure):
+    def reconfigure(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
         vllm_config = self.vllm_config
-        org_scheduler_config = vllm_config.org_scheduler_config
-
-        max_num_seqs = config.max_num_seqs or org_scheduler_config.max_num_seqs
+        max_num_seqs = max_num_seqs or vllm_config.org_max_num_seqs
         max_num_batched_tokens = (
-            config.max_num_batched_tokens or org_scheduler_config.max_num_batched_tokens
+            max_num_batched_tokens or vllm_config.org_max_num_batched_tokens
        )
         # The reconfigured values can only be less than or equal to their original
         # values. Otherwise, it may lead to an OOM, or CUDA graphs are not covered.
-        if max_num_seqs > org_scheduler_config.max_num_seqs:
+        if max_num_seqs > vllm_config.org_max_num_seqs:
             return False
-        if max_num_batched_tokens > org_scheduler_config.max_num_batched_tokens:
+        if max_num_batched_tokens > vllm_config.org_max_num_batched_tokens:
             return False
 
         scheduler_config = self.vllm_config.scheduler_config
-        scheduler_config.max_num_seqs = config.max_num_seqs
-        scheduler_config.max_num_batched_tokens = config.max_num_batched_tokens
+        scheduler_config.max_num_seqs = max_num_seqs
+        scheduler_config.max_num_batched_tokens = max_num_batched_tokens
 
         self.scheduler.reconfigure(
             max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 0a1e9b829520..b576084df350 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -36,7 +36,6 @@
     EngineCoreRequestType,
     ReconfigureDistributedRequest,
     ReconfigureRankType,
-    SchedulerReconfigure,
     UtilityOutput,
 )
 from vllm.v1.engine.coordinator import DPCoordinator
@@ -179,7 +178,9 @@ def save_sharded_state(
     ) -> None:
         raise NotImplementedError
 
-    def reconfigure_scheduler(self, config: SchedulerReconfigure):
+    def reconfigure(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
         raise NotImplementedError
 
     def collective_rpc(
@@ -343,8 +344,10 @@ def collective_rpc(
     def dp_engines_running(self) -> bool:
         return False
 
-    def reconfigure_scheduler(self, config: SchedulerReconfigure):
-        self.engine_core.reconfigure_scheduler(config)
+    def reconfigure(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
+        return self.engine_core.reconfigure(max_num_seqs, max_num_batched_tokens)
 
 
 @dataclass
@@ -811,8 +814,10 @@ def save_sharded_state(
     ) -> None:
         self.call_utility("save_sharded_state", path, pattern, max_size)
 
-    def reconfigure_scheduler(self, config: SchedulerReconfigure):
-        self.call_utility("reconfigure_scheduler", config)
+    def reconfigure(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
+        return self.call_utility("reconfigure", max_num_seqs, max_num_batched_tokens)
 
 
 class AsyncMPClient(MPClient):
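Usage sketch for the offline path after patches 1-4 (not part of the patch series itself; the model name and the concrete limit values below are placeholders, and this assumes the series is applied):

```python
from vllm import LLM

# Start with the usual startup limits; these become the upper bounds that
# reconfigure() is allowed to move back up to later.
llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    max_num_seqs=256,
    max_num_batched_tokens=8192,
)

# Shrink the scheduler limits at runtime. The call returns False (and changes
# nothing) if either value would exceed the limit the engine was started with.
ok = llm.reconfigure(max_num_seqs=64, max_num_batched_tokens=2048)
assert ok

# Passing None restores the original (startup) value for that field.
llm.reconfigure(max_num_seqs=None, max_num_batched_tokens=None)
```
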
From df83ded95d01159eb292ab4421deca2ada27da9d Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 11 Dec 2025 15:33:45 +0800
Subject: [PATCH 5/6] + online part

Signed-off-by: wang.yuqi
---
 vllm/engine/protocol.py                             |  8 ++++++++
 vllm/entrypoints/serve/__init__.py                  |  6 +++---
 vllm/entrypoints/serve/{sleep => dev}/__init__.py   |  0
 vllm/entrypoints/serve/{sleep => dev}/api_router.py | 10 ++++++++++
 vllm/entrypoints/serve/dev/protocol.py              |  9 +++++++++
 vllm/v1/engine/async_llm.py                         | 12 ++++++++++++
 vllm/v1/engine/core_client.py                       | 12 ++++++++++++
 7 files changed, 54 insertions(+), 3 deletions(-)
 rename vllm/entrypoints/serve/{sleep => dev}/__init__.py (100%)
 rename vllm/entrypoints/serve/{sleep => dev}/api_router.py (82%)
 create mode 100644 vllm/entrypoints/serve/dev/protocol.py

diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index d94951a0cffc..ba200ba50bba 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -188,3 +188,11 @@ async def collective_rpc(
     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         """Get supported tasks"""
         raise NotImplementedError
+
+    def reconfigure(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
+        raise NotImplementedError
+
+    async def reconfigure_async(self, max_num_seqs: int | None, max_num_batched_tokens: int | None):
+        raise NotImplementedError
diff --git a/vllm/entrypoints/serve/__init__.py b/vllm/entrypoints/serve/__init__.py
index c4fcc92db931..90fa1eb4af64 100644
--- a/vllm/entrypoints/serve/__init__.py
+++ b/vllm/entrypoints/serve/__init__.py
@@ -23,11 +23,11 @@ def register_vllm_serve_api_routers(app: FastAPI):
 
     attach_profile_router(app)
 
-    from vllm.entrypoints.serve.sleep.api_router import (
-        attach_router as attach_sleep_router,
+    from vllm.entrypoints.serve.dev.api_router import (
+        attach_router as attach_dev_router,
     )
 
-    attach_sleep_router(app)
+    attach_dev_router(app)
 
     from vllm.entrypoints.serve.tokenize.api_router import (
         attach_router as attach_tokenize_router,
diff --git a/vllm/entrypoints/serve/sleep/__init__.py b/vllm/entrypoints/serve/dev/__init__.py
similarity index 100%
rename from vllm/entrypoints/serve/sleep/__init__.py
rename to vllm/entrypoints/serve/dev/__init__.py
diff --git a/vllm/entrypoints/serve/sleep/api_router.py b/vllm/entrypoints/serve/dev/api_router.py
similarity index 82%
rename from vllm/entrypoints/serve/sleep/api_router.py
rename to vllm/entrypoints/serve/dev/api_router.py
index bc01e185315c..83616fcc82aa 100644
--- a/vllm/entrypoints/serve/sleep/api_router.py
+++ b/vllm/entrypoints/serve/dev/api_router.py
@@ -7,6 +7,7 @@
 
 import vllm.envs as envs
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.serve.dev.protocol import ReconfigureRequest
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -49,6 +50,15 @@ async def is_sleeping(raw_request: Request):
     return JSONResponse(content={"is_sleeping": is_sleeping})
 
 
+@router.post("/reconfigure")
+async def reconfigure(request: ReconfigureRequest, raw_request: Request):
+    success = await engine_client(raw_request).reconfigure_async(
+        max_num_seqs=request.max_num_seqs,
+        max_num_batched_tokens=request.max_num_batched_tokens,
+    )
+    return JSONResponse(content={"success": success})
+
+
 def attach_router(app: FastAPI):
     if not envs.VLLM_SERVER_DEV_MODE:
         return
diff --git a/vllm/entrypoints/serve/dev/protocol.py b/vllm/entrypoints/serve/dev/protocol.py
new file mode 100644
index 000000000000..826fe698bf13
--- /dev/null
+++ b/vllm/entrypoints/serve/dev/protocol.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from pydantic import BaseModel
+
+
+class ReconfigureRequest(BaseModel):
+    max_num_seqs: int | None = None
+    max_num_batched_tokens: int | None = None
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 8eff61563cce..dc8b29f7c69b 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -848,6 +848,18 @@ async def scale_elastic_ep(
             custom_stat_loggers=None,
         )
 
+    async def reconfigure(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
+        return self.engine_core.reconfigure(max_num_seqs, max_num_batched_tokens)
+
+    async def reconfigure_async(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ):
+        return await self.engine_core.reconfigure_async(
+            max_num_seqs, max_num_batched_tokens
+        )
+
     @property
     def is_running(self) -> bool:
         # Is None before the loop is started.
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index b576084df350..2ea38cd2a89f 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -258,6 +258,11 @@ async def collective_rpc_async(
     ) -> list[_R]:
         raise NotImplementedError
 
+    async def reconfigure_async(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
+        raise NotImplementedError
+
 
 class InprocClient(EngineCoreClient):
     """
@@ -1029,6 +1034,13 @@ async def collective_rpc_async(
             "collective_rpc", method, timeout, args, kwargs
         )
 
+    async def reconfigure_async(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ) -> bool:
+        return await self.call_utility_async(
+            "reconfigure", max_num_seqs, max_num_batched_tokens
+        )
+
 
 class DPAsyncMPClient(AsyncMPClient):
     """Asyncio-compatible client for multi-proc, multi-engine (data parallel)

From 70b5c0bada9096026ec73c38f2bd1a5a83348b70 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 11 Dec 2025 17:06:45 +0800
Subject: [PATCH 6/6] mypy

Signed-off-by: wang.yuqi
---
 vllm/engine/protocol.py     | 4 +++-
 vllm/v1/engine/async_llm.py | 4 ++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index ba200ba50bba..d3ca35ac2df1 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -194,5 +194,7 @@ def reconfigure(
     ) -> bool:
         raise NotImplementedError
 
-    async def reconfigure_async(self, max_num_seqs: int | None, max_num_batched_tokens: int | None):
+    async def reconfigure_async(
+        self, max_num_seqs: int | None, max_num_batched_tokens: int | None
+    ):
         raise NotImplementedError
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index dc8b29f7c69b..d9054eb5d6aa 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -848,14 +848,14 @@ async def scale_elastic_ep(
             custom_stat_loggers=None,
         )
 
-    async def reconfigure(
+    def reconfigure(
         self, max_num_seqs: int | None, max_num_batched_tokens: int | None
     ) -> bool:
         return self.engine_core.reconfigure(max_num_seqs, max_num_batched_tokens)
 
     async def reconfigure_async(
         self, max_num_seqs: int | None, max_num_batched_tokens: int | None
-    ):
+    ) -> bool:
         return await self.engine_core.reconfigure_async(
             max_num_seqs, max_num_batched_tokens
         )
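A matching sketch for the online path from patch 5 (again illustrative, not part of the patches): the `/reconfigure` route is only attached when the server runs with `VLLM_SERVER_DEV_MODE` enabled, and the host, port, and limit values below are placeholders.

```python
import requests

# Assumes a dev-mode server (VLLM_SERVER_DEV_MODE=1) with this series applied.
resp = requests.post(
    "http://localhost:8000/reconfigure",  # placeholder host/port
    json={"max_num_seqs": 64, "max_num_batched_tokens": 2048},
)
# {"success": true} if both values stay within the startup limits,
# {"success": false} otherwise.
print(resp.json())
```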