From 91a5ae188fc8396cb2ed1e87e04c05f9f7c9d626 Mon Sep 17 00:00:00 2001
From: Yanzhao Wang
Date: Sat, 9 May 2026 11:53:25 +0800
Subject: [PATCH] Add layerwise load-ahead window

---
 .../user-guide/prefix-cache/pipeline_store.md |  10 +-
 examples/ucm_config_example.yaml              |   4 +
 .../vllm/tests/test_layerwise_load_ahead.py   | 352 ++++++++++++++++++
 ucm/integration/vllm/ucm_connector.py         |  62 ++-
 4 files changed, 417 insertions(+), 11 deletions(-)
 create mode 100644 ucm/integration/vllm/tests/test_layerwise_load_ahead.py

diff --git a/docs/source/user-guide/prefix-cache/pipeline_store.md b/docs/source/user-guide/prefix-cache/pipeline_store.md
index 8960daeb7..442b499d2 100644
--- a/docs/source/user-guide/prefix-cache/pipeline_store.md
+++ b/docs/source/user-guide/prefix-cache/pipeline_store.md
@@ -161,8 +161,16 @@ ucm_connectors:
     store_pipeline: "Cache|Posix"
     storage_backends: "/mnt/test"
 use_layerwise: true
+layerwise_load_ahead: 1
 ```

+`layerwise_load_ahead` controls how many local layers are submitted for KV load
+before inference reaches them. The default value `1` preserves the previous
+one-layer lookahead behavior. If layerwise load latency is higher than per-layer
+inference latency, try `2` or `4` to improve overlap. Larger values increase
+pressure on the Cache Store queue, host buffers, and the H2D stream, so tune
+this value for your model and storage backend.
+
 **⚠️ Make sure to replace `"/vllm-workspace/unified-cache-management/examples/ucm_config_example.yaml"` with your actual config file path.**

 If you see log as below:
@@ -262,4 +270,4 @@ This log indicates that the **Posix Store** has received a **load or dump task**
 ```text
 [UC][D] Posix task({task_id},{operation},{subtask_number},{size}) finished, cost {time}ms. [PID,TID]
 ```
-This log indicates that a load or dump task in the **Posix Store** has completed, along with its execution time in **in ms**.
\ No newline at end of file
+This log indicates that a load or dump task in the **Posix Store** has completed, along with its execution time **in ms**.
diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml
index f9deb9ae2..b51fe83b6 100644
--- a/examples/ucm_config_example.yaml
+++ b/examples/ucm_config_example.yaml
@@ -31,6 +31,10 @@ enable_event_sync: true

 # Whether to use layerwise loading/saving (optional, default: True for UCMConnector)
 use_layerwise: true
+# Number of local layers to submit for layerwise KV load ahead of inference.
+# 1 preserves the default one-layer lookahead behavior. Try 2 or 4 when load
+# latency is higher than per-layer inference latency.
+layerwise_load_ahead: 1 # hit_ratio: 0.9 # Whether to record requests' traces (optional, default: False) diff --git a/ucm/integration/vllm/tests/test_layerwise_load_ahead.py b/ucm/integration/vllm/tests/test_layerwise_load_ahead.py new file mode 100644 index 000000000..b29c1550f --- /dev/null +++ b/ucm/integration/vllm/tests/test_layerwise_load_ahead.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +import importlib.util +import sys +import types +from collections import defaultdict +from pathlib import Path + +_MISSING = object() +_STUBBED_MODULE_NAMES = [ + "torch", + "numpy", + "vllm", + "vllm.config", + "vllm.distributed", + "vllm.distributed.kv_transfer", + "vllm.distributed.kv_transfer.kv_connector", + "vllm.distributed.kv_transfer.kv_connector.v1", + "vllm.distributed.kv_transfer.kv_connector.v1.base", + "vllm.distributed.parallel_state", + "vllm.model_executor", + "vllm.model_executor.models", + "vllm.model_executor.models.utils", + "vllm.platforms", + "vllm.v1", + "vllm.v1.core", + "vllm.v1.core.sched", + "vllm.v1.core.sched.output", + "ucm", + "ucm.integration", + "ucm.integration.vllm", + "ucm.integration.vllm.device", + "ucm.logger", + "ucm.observability", + "ucm.shared", + "ucm.shared.metrics", + "ucm.store", + "ucm.store.factory_v1", + "ucm.store.ucmstore_v1", + "ucm.utils", + "ucm.sparse", + "ucm.sparse.state", +] + + +def _module(name: str, *, package: bool = False) -> types.ModuleType: + module = types.ModuleType(name) + if package: + module.__path__ = [] + sys.modules[name] = module + return module + + +def _install_dependency_stubs() -> dict[str, object]: + previous_modules = { + name: sys.modules.get(name, _MISSING) for name in _STUBBED_MODULE_NAMES + } + for name in _STUBBED_MODULE_NAMES: + sys.modules.pop(name, None) + + torch = _module("torch") + torch.Tensor = type("Tensor", (), {}) + torch.dtype = type("dtype", (), {}) + + numpy = _module("numpy") + numpy.ndarray = list + numpy.uint64 = int + numpy.asarray = lambda value, dtype=None: value + numpy.ascontiguousarray = lambda value: value + + _module("vllm", package=True) + config = _module("vllm.config") + config.VllmConfig = type("VllmConfig", (), {}) + + _module("vllm.distributed", package=True) + _module("vllm.distributed.kv_transfer", package=True) + _module("vllm.distributed.kv_transfer.kv_connector", package=True) + _module("vllm.distributed.kv_transfer.kv_connector.v1", package=True) + base = _module("vllm.distributed.kv_transfer.kv_connector.v1.base") + + class KVConnectorBase_V1: + def __init__(self, vllm_config=None, role=None): + self._vllm_config = vllm_config + self._role = role + self._connector_metadata = None + + def _get_connector_metadata(self): + return self._connector_metadata + + def bind_connector_metadata(self, connector_metadata): + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self): + self._connector_metadata = None + + base.KVConnectorBase_V1 = KVConnectorBase_V1 + base.KVConnectorMetadata = type("KVConnectorMetadata", (), {}) + base.KVConnectorRole = types.SimpleNamespace(SCHEDULER="scheduler", WORKER="worker") + + parallel_state = _module("vllm.distributed.parallel_state") + parallel_state.get_world_group = lambda: types.SimpleNamespace(local_rank=0, rank=0) + + _module("vllm.model_executor", package=True) + _module("vllm.model_executor.models", package=True) + model_utils = _module("vllm.model_executor.models.utils") + model_utils.extract_layer_index = lambda name: int(name.rsplit(".", 1)[1]) + + platforms = _module("vllm.platforms") + 
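+    # Fake a CPU-only platform: the stub only needs enough surface for
+    # ucm_connector to be imported without a real torch / CUDA stack.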
platforms.current_platform = types.SimpleNamespace( + is_cuda_alike=lambda: False, + device_type="cpu", + ) + + _module("vllm.v1", package=True) + _module("vllm.v1.core", package=True) + _module("vllm.v1.core.sched", package=True) + sched_output = _module("vllm.v1.core.sched.output") + sched_output.SchedulerOutput = type("SchedulerOutput", (), {}) + + _module("ucm", package=True) + _module("ucm.integration", package=True) + _module("ucm.integration.vllm", package=True) + device = _module("ucm.integration.vllm.device") + device.create_device = lambda: None + + logger_module = _module("ucm.logger") + + class DummyLogger: + def info(self, *args, **kwargs): + pass + + def info_once(self, *args, **kwargs): + pass + + def warning(self, *args, **kwargs): + pass + + def error(self, *args, **kwargs): + pass + + logger_module.init_logger = lambda name=None: DummyLogger() + + observability = _module("ucm.observability") + observability.PrometheusStatsLogger = type("PrometheusStatsLogger", (), {}) + + _module("ucm.shared", package=True) + metrics = _module("ucm.shared.metrics") + metrics.ucmmetrics = types.SimpleNamespace() + + _module("ucm.store", package=True) + factory = _module("ucm.store.factory_v1") + factory.UcmConnectorFactoryV1 = type("UcmConnectorFactoryV1", (), {}) + store_base = _module("ucm.store.ucmstore_v1") + store_base.Task = type("Task", (), {}) + store_base.UcmKVStoreBaseV1 = type("UcmKVStoreBaseV1", (), {}) + + utils = _module("ucm.utils") + utils.Config = type("Config", (), {}) + + _module("ucm.sparse", package=True) + sparse_state = _module("ucm.sparse.state") + sparse_state.has_ucm_sparse = lambda: False + return previous_modules + + +def _restore_dependency_stubs(previous_modules: dict[str, object]) -> None: + for name in _STUBBED_MODULE_NAMES: + previous_module = previous_modules[name] + if previous_module is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = previous_module + + +def _load_connector_module(): + previous_modules = _install_dependency_stubs() + module_name = "ucm.integration.vllm.ucm_connector" + previous_connector_module = sys.modules.get(module_name, _MISSING) + try: + sys.modules.pop(module_name, None) + module_path = ( + Path(__file__).resolve().parents[4] + / "ucm" + / "integration" + / "vllm" + / "ucm_connector.py" + ) + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + finally: + if previous_connector_module is _MISSING: + sys.modules.pop(module_name, None) + else: + sys.modules[module_name] = previous_connector_module + _restore_dependency_stubs(previous_modules) + + +class FakeTask: + def __init__(self, layer_id: int, block_ids: tuple[bytes, ...]): + self.layer_id = layer_id + self.block_ids = block_ids + + +class FakeStore: + def __init__(self, fail_block_ids: set[bytes] | None = None): + self.fail_block_ids = fail_block_ids or set() + self.load_attempts = [] + self.loads = [] + self.waited_layers = [] + + def load_data(self, block_ids, shard_indexes, layer_ptrs): + layer_id = shard_indexes[0] + block_ids_tuple = tuple(block_ids) + self.load_attempts.append((layer_id, block_ids_tuple)) + if self.fail_block_ids.intersection(block_ids_tuple): + raise RuntimeError("load submit failed") + + task = FakeTask(layer_id, block_ids_tuple) + self.loads.append( + { + "layer_id": layer_id, + "block_ids": block_ids_tuple, + "ptrs": list(layer_ptrs), + "task": 
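+                # keep the returned handle so tests can pair each submitted
+                # load with its later store.wait() call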
task, + } + ) + return task + + def wait(self, task): + self.waited_layers.append(task.layer_id) + + +class FakeKVCacheLayout: + def __init__(self, layer_count: int): + self.layer_count = layer_count + + def extract_block_addrs(self, vllm_block_ids, layer_first=False): + assert layer_first is True + rows = [] + for local_row in range(self.layer_count): + rows.append([[local_row * 1000 + block_id] for block_id in vllm_block_ids]) + return rows + + +def _make_metadata(*, include_failing_request: bool = False): + request_meta = { + "ok": types.SimpleNamespace( + load_block_ids=([b"ok-block"], [7]), + dump_block_ids=([], []), + ) + } + if include_failing_request: + request_meta["bad"] = types.SimpleNamespace( + load_block_ids=([b"bad-block"], [99]), + dump_block_ids=([], []), + ) + return types.SimpleNamespace(request_meta=request_meta) + + +def _make_connector(module, *, load_ahead: int, store: FakeStore | None = None): + connector = module.UCMLayerWiseConnector.__new__(module.UCMLayerWiseConnector) + connector.load_tasks = defaultdict(dict) + connector.store = store or FakeStore() + connector.request_data = [] + connector._failure_req_ids = set() + connector._invalid_block_ids = set() + connector.layerwise_load_ahead = load_ahead + connector.layer_ids = [10, 11, 12, 13, 14] + connector.layer_name_to_id = { + f"layer.{layer_id}": layer_id for layer_id in connector.layer_ids + } + connector.first_layer_id = 10 + connector.kv_cache_layout = FakeKVCacheLayout(len(connector.layer_ids)) + connector.tp_rank = 0 + connector.tp_size = 1 + connector.is_mla = False + connector.request_hasher = lambda block_id: block_id + connector.need_load = False + connector._connector_metadata = _make_metadata() + return connector + + +def _loaded_layers(store: FakeStore) -> list[int]: + return [load["layer_id"] for load in store.loads] + + +def test_layerwise_load_ahead_one_preserves_single_layer_submission(): + module = _load_connector_module() + store = FakeStore() + connector = _make_connector(module, load_ahead=1, store=store) + + connector.start_load_kv(None) + assert _loaded_layers(store) == [10] + + connector.wait_for_layer_load("layer.10") + assert store.waited_layers == [10] + assert _loaded_layers(store) == [10, 11] + + +def test_layerwise_load_ahead_prefetches_window_and_refills_by_layer_order(): + module = _load_connector_module() + store = FakeStore() + connector = _make_connector(module, load_ahead=3, store=store) + + connector.start_load_kv(None) + assert _loaded_layers(store) == [10, 11, 12] + + connector.wait_for_layer_load("layer.10") + assert store.waited_layers == [10] + assert _loaded_layers(store) == [10, 11, 12, 13] + + connector.wait_for_layer_load("layer.11") + assert store.waited_layers == [10, 11] + assert _loaded_layers(store) == [10, 11, 12, 13, 14] + + for layer_id in [12, 13, 14]: + connector.wait_for_layer_load(f"layer.{layer_id}") + + assert store.waited_layers == [10, 11, 12, 13, 14] + assert _loaded_layers(store) == [10, 11, 12, 13, 14] + + +def test_layerwise_load_ahead_skips_failed_request_in_future_layers(): + module = _load_connector_module() + store = FakeStore(fail_block_ids={b"bad-block"}) + connector = _make_connector(module, load_ahead=3, store=store) + connector._connector_metadata = _make_metadata(include_failing_request=True) + + connector.start_load_kv(None) + + assert connector._failure_req_ids == {"bad"} + assert connector._invalid_block_ids == {99} + assert _loaded_layers(store) == [10, 11, 12] + assert [ + layer_id + for layer_id, block_ids in store.load_attempts 
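+        # the failing request's block must be attempted on the first layer only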
+ if block_ids == (b"bad-block",) + ] == [10] + + connector.wait_for_layer_load("layer.10") + connector.wait_for_layer_load("layer.11") + + assert _loaded_layers(store) == [10, 11, 12, 13, 14] + assert [ + layer_id + for layer_id, block_ids in store.load_attempts + if block_ids == (b"bad-block",) + ] == [10] diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index e89fc1f4b..0de864eb2 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -370,6 +370,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_data_size = self.kv_cache_layout.block_size self.layer_name_to_id = self.kv_cache_layout.layer_name_to_id self.layer_ids = sorted(set(self.layer_name_to_id.values())) + self.layer_id_to_local_row = { + layer_id: local_row for local_row, layer_id in enumerate(self.layer_ids) + } self.first_layer_id = self.layer_ids[0] self.device = create_device() @@ -780,7 +783,47 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.dump_total_ptrs: np.ndarray | None = None self.request_data: list[tuple[str, list, np.ndarray]] = [] self._failure_req_ids: set[str] = set() - logger.info("Init UCMLayerWiseConnector.") + self.layerwise_load_ahead = self._get_layerwise_load_ahead() + self._next_load_layer_index = 0 + self._submitted_load_layers: set[int] = set() + self._waited_load_layers: set[int] = set() + logger.info( + f"Init UCMLayerWiseConnector with layerwise_load_ahead={self.layerwise_load_ahead}." + ) + + def _get_layerwise_load_ahead(self) -> int: + raw_load_ahead = self.launch_config.get("layerwise_load_ahead", 1) + try: + load_ahead = int(raw_load_ahead) + except (TypeError, ValueError) as exc: + raise ValueError( + "layerwise_load_ahead must be a positive integer." 
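+                # chain the original conversion error as the cause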
+ ) from exc + if load_ahead < 1: + raise ValueError("layerwise_load_ahead must be a positive integer.") + return load_ahead + + def _get_layer_local_row(self, layer_id: int) -> int: + layer_id_to_local_row = getattr(self, "layer_id_to_local_row", None) + if layer_id_to_local_row is not None: + return layer_id_to_local_row[layer_id] + return self.layer_ids.index(layer_id) + + def _submit_next_load_layers( + self, metadata: "UCMConnectorMetadata", count: int + ) -> None: + submitted_count = 0 + while submitted_count < count and self._next_load_layer_index < len( + self.layer_ids + ): + layer_id = self.layer_ids[self._next_load_layer_index] + self._next_load_layer_index += 1 + if layer_id in self._submitted_load_layers: + continue + self._submitted_load_layers.add(layer_id) + local_row = self._get_layer_local_row(layer_id) + self._submit_request_load_tasks_for_layer(layer_id, local_row, metadata) + submitted_count += 1 def _submit_request_load_tasks_for_layer( self, @@ -810,6 +853,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self.load_tasks.clear() self.request_data.clear() self._failure_req_ids.clear() + self._submitted_load_layers = set() + self._waited_load_layers = set() + self._next_load_layer_index = 0 self.need_load = False for request_id, request in metadata.request_meta.items(): @@ -827,7 +873,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self.request_data.append((request_id, ucm_block_ids, total_ptrs)) if self.need_load: - self._submit_request_load_tasks_for_layer(self.first_layer_id, 0, metadata) + self._submit_next_load_layers(metadata, self.layerwise_load_ahead) def wait_for_layer_load(self, layer_name: str) -> None: if not self._connector_metadata: @@ -836,6 +882,8 @@ def wait_for_layer_load(self, layer_name: str) -> None: return metadata = self._get_connector_metadata() current_layer_id = self.layer_name_to_id[layer_name] + should_refill_window = current_layer_id not in self._waited_load_layers + self._waited_load_layers.add(current_layer_id) # Pop before wait so MTP / rollback paths that revisit the same layer_name # do not call store.wait() again on already-completed handles. @@ -852,14 +900,8 @@ def wait_for_layer_load(self, layer_name: str) -> None: ) self._failure_req_ids.add(request_id) - next_layer_id = current_layer_id + 1 - if next_layer_id not in self.layer_ids: - return - next_local_row = next_layer_id - self.first_layer_id - - self._submit_request_load_tasks_for_layer( - next_layer_id, next_local_row, metadata - ) + if should_refill_window: + self._submit_next_load_layers(metadata, 1) def save_kv_layer( self,
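
For review, the window logic in `_submit_next_load_layers` and
`wait_for_layer_load` can be modeled as a tiny standalone schedule: submit
`layerwise_load_ahead` layers before the first layer runs, then top the window
up by one on each layer's first wait. The sketch below illustrates that
schedule under this reading; `LoadAheadWindow`, `submit`, and `on_layer_waited`
are hypothetical names, not part of the connector API.

```python
from collections.abc import Callable, Sequence


class LoadAheadWindow:
    """Standalone model of the layerwise load-ahead schedule."""

    def __init__(
        self,
        layer_ids: Sequence[int],
        load_ahead: int,
        submit: Callable[[int], None],
    ) -> None:
        if load_ahead < 1:
            raise ValueError("load_ahead must be a positive integer.")
        self._layer_ids = list(layer_ids)
        self._submit = submit
        self._next = 0  # index of the first layer not yet submitted
        self._waited: set[int] = set()
        self._top_up(load_ahead)  # prime the window before the first layer runs

    def _top_up(self, count: int) -> None:
        # Submit up to `count` further layers, strictly in layer order.
        end = min(self._next + count, len(self._layer_ids))
        for index in range(self._next, end):
            self._submit(self._layer_ids[index])
        self._next = end

    def on_layer_waited(self, layer_id: int) -> None:
        # Refill by one only on the first wait for a layer, so paths that
        # revisit a layer (e.g. MTP / rollback) do not grow the window.
        if layer_id not in self._waited:
            self._waited.add(layer_id)
            self._top_up(1)


submitted: list[int] = []
window = LoadAheadWindow([10, 11, 12, 13, 14], 3, submitted.append)
assert submitted == [10, 11, 12]  # window primed before inference starts
window.on_layer_waited(10)
assert submitted == [10, 11, 12, 13]  # each first wait tops the window up by one
window.on_layer_waited(10)  # a revisit does not submit again
assert submitted == [10, 11, 12, 13]
```

With `load_ahead=1` the schedule degenerates to the previous one-layer
lookahead, which is the behavior the first test in
test_layerwise_load_ahead.py pins down.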