
Commit 9d41318

hlin99 and deng451e authored
feat(kv_cache): enable asymmetric store/retrieve storages in PD backend (LMCache#2509)
* feat(kv_cache): enable asymmetric save/remote storage in PD backend

  Remove the restriction that prevented using `save_decode_cache` and `remote_backend` simultaneously in Prefill-Decode (PD) separation scenarios. This change introduces `pd_retrieve_locations` and `pd_store_location` parameters to decouple the KV cache retrieval and storage logic. This enables an asymmetric cache flow:

  1. Prefill nodes transmit KV cache to Decode nodes via the PDBackend.
  2. Decode nodes write back their generated KV cache to a remote backend for subsequent prefill reuse.
  3. In multi-turn dialogue scenarios, subsequent prefill requests retrieve historical KV cache from the remote backend, significantly increasing Prefix Cache hit rates and reducing TTFT.

  This decoupling provides greater flexibility for cross-instance cache management and improves overall pipeline efficiency in distributed inference.

  [ Compute Layer ]
  +----------------------+                 +------------------+
  | Prefill Node         | ===============>| Decode Node      |
  | (Hit-Remote & GenKV) |  (1) PDBackend  | (Hit-PD & GenKV) |
  +-------^--------------+                 +-------+----------+
          |                                        |
          :                                        :
  --------|----------------------------------------|------------
  [ Storage Layer ]
          |                                        |
          | (3) pd_retrieve_locations              | (2) pd_store_location
          |     (Pool -> Prefill)                  |     (Decode -> Pool)
          |                                        v
  +-------+-----------------------------------------------+
  |               Distributed Storage Pool                |
  |    [Node A]     [Node B]     [Node C]     [Node D]    |
  |    <======= (Object Storage / NFS / DFS) =======>     |
  +-------------------------------------------------------+

  Workflow:
  1. Prefill -> Decode (PDBackend): Initial KV transfer for the current turn.
  2. Decode -> Remote (Store): Decode saves updated KV to NFS for persistence.
  3. Remote -> Prefill (Retrieve): Next-turn prefill pulls from Remote, drastically increasing Prefix Cache hit rate for multi-turn dialogues.
* small refactor

* config examples for pd + remote backends

* refactor: rename pd_retrieve_locations/pd_store_location to retrieve_locations/store_location

  Remove the PD-specific prefix to make the retrieve/store locations generic instead of being limited to PD only. This breaks the PD-only feature restriction and allows the mechanism to be reused by other roles/components.

* move retrieve & store locations from storage manager to cache engine

* add parameter validation check

* config: replace hardcoded IP with placeholder in decoder remote configs

* resolve conflicts and rebase to the latest

* address review comments

* add description in configurations.rst

---------

Signed-off-by: Tony Lin <tony.lin@intel.com>
Co-authored-by: deng451e <57919305+deng451e@users.noreply.github.com>
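The three-step flow above can be sketched with toy in-memory backends. Everything here (the `Backend` class, `store`/`retrieve` helpers) is illustrative, not LMCache's real API; only the backend names mirror the config values used in this commit.

```python
# Toy sketch of the asymmetric KV-cache flow. Backend names mirror the
# config values ("PDBackend", "RemoteBackend", ...); the classes and
# functions are illustrative assumptions, not LMCache's real API.

class Backend:
    def __init__(self):
        self.kv = {}

    def put(self, key, val):
        self.kv[key] = val

    def get(self, key):
        return self.kv.get(key)

backends = {
    "LocalCPUBackend": Backend(),
    "PDBackend": Backend(),
    "RemoteBackend": Backend(),
}

def store(key, val, location=None):
    # store_location=None -> store to all backends; else only the named one.
    targets = backends.values() if location is None else [backends[location]]
    for b in targets:
        b.put(key, val)

def retrieve(key, locations=None):
    # retrieve_locations=None -> search all backends in order.
    names = locations or list(backends)
    for name in names:
        hit = backends[name].get(key)
        if hit is not None:
            return hit, name
    return None, None

# (1) Turn 1: prefill pushes KV to decode via PDBackend (one-way channel).
backends["PDBackend"].put("turn1", "kv-from-prefill")
# Decode consumes it (decoder retrieve_locations=["PDBackend"]) ...
kv, _ = retrieve("turn1", ["PDBackend"])
# (2) ... then writes its updated KV to the pool (store_location="RemoteBackend").
store("turn1+decode", kv + "+decode", "RemoteBackend")
# (3) Turn 2: next-turn prefill finds the history in the remote pool.
hit, src = retrieve("turn1+decode", ["LocalCPUBackend", "RemoteBackend"])
# hit == "kv-from-prefill+decode", src == "RemoteBackend"
```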
1 parent d6661f1 commit 9d41318

8 files changed

Lines changed: 149 additions & 11 deletions

docs/source/api_reference/configurations.rst

Lines changed: 7 additions & 1 deletion
@@ -79,6 +79,12 @@ Basic cache settings that control the core functionality of LMCache.
    * - min_retrieve_tokens
      - LMCACHE_MIN_RETRIEVE_TOKENS
      - Minimum number of hit tokens required to perform retrieve. If hit tokens < this value, skip retrieve but still record the hits to avoid re-storing existing chunks. See :ref:`performance_tuning` for a working example. Default: 0 (disabled)
+   * - store_location
+     - LMCACHE_STORE_LOCATION
+     - A single storage backend name to store KV caches into. When specified, only the matching backend receives store operations. Valid values are the backend class names registered in the storage manager, including ``"LocalCPUBackend"``, ``"LocalDiskBackend"``, ``"RemoteBackend"``, ``"PDBackend"``, ``"P2PBackend"``, ``"GdsBackend"``, etc., and any storage plugin backends. Note: ``"PDBackend"`` cannot be used as a store location for a decoder instance in a PD setup, since PDBackend is one-way from prefiller to decoder only. Default: null (store to all active backends)
+   * - retrieve_locations
+     - LMCACHE_RETRIEVE_LOCATIONS
+     - List of storage backend names to search when retrieving or looking up KV caches. When specified, only the listed backends are searched. Valid values are the backend class names registered in the storage manager, including ``"LocalCPUBackend"``, ``"LocalDiskBackend"``, ``"RemoteBackend"``, ``"PDBackend"``, ``"P2PBackend"``, ``"GdsBackend"``, etc., and any storage plugin backends. Default: null (search all active backends)
    * - extra_config
      - LMCACHE_EXTRA_CONFIG={"key": value, ...}
      - Additional configuration as JSON dict. For NUMA manual mode, include "gpu_to_numa_mapping": {gpu_id: numa_node, ...}. Default: {}
@@ -475,4 +481,4 @@ These configurations are deprecated and may be removed in future versions.
    * - audit_actual_remote_url
      - LMCACHE_AUDIT_ACTUAL_REMOTE_URL
-     - (Deprecated) URL of actual remote LMCache instance for auditing. Use extra_config['audit_actual_remote_url'] instead
+     - (Deprecated) URL of actual remote LMCache instance for auditing. Use extra_config['audit_actual_remote_url'] instead
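The two documented environment variables can be sketched as follows. How LMCache actually parses the list value (JSON vs. comma-separated) is an assumption here; `load_locations` is a hypothetical helper, not part of the library.

```python
import json
import os

# Hedged sketch of reading the two new options from the environment
# variables documented above. JSON-list parsing for
# LMCACHE_RETRIEVE_LOCATIONS is an assumption, not LMCache's
# documented parser.

def load_locations(environ=os.environ):
    # A single backend name, or None (store to all active backends).
    store = environ.get("LMCACHE_STORE_LOCATION")
    # A list of backend names, or None (search all active backends).
    raw = environ.get("LMCACHE_RETRIEVE_LOCATIONS")
    retrieve = json.loads(raw) if raw else None
    return store, retrieve

env = {
    "LMCACHE_STORE_LOCATION": "RemoteBackend",
    "LMCACHE_RETRIEVE_LOCATIONS": '["PDBackend"]',
}
store, retrieve = load_locations(env)
# store == "RemoteBackend", retrieve == ["PDBackend"]
```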
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
local_cpu: True
max_local_cpu_size: 5

remote_url: "lm://localhost:6800"
remote_serde: "cachegen"

retrieve_locations: ["PDBackend"]
store_location: "RemoteBackend"

enable_pd: True
transfer_channel: "nixl"
pd_role: "receiver"
pd_peer_host: "localhost"
pd_peer_init_port: 7300
pd_peer_alloc_port: 7400
pd_buffer_size: 2147483648 # 2GB
pd_buffer_device: "cuda"
nixl_backends: [UCX]

save_decode_cache: true
save_unfull_chunk: true
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
local_cpu: True
max_local_cpu_size: 5

remote_url: "lm://localhost:6800"
remote_serde: "cachegen"

retrieve_locations: ["LocalCPUBackend", "RemoteBackend"]

enable_pd: True
transfer_channel: "nixl"
pd_role: "sender"
pd_proxy_host: "localhost"
pd_proxy_port: 7500
pd_buffer_size: 1073741824 # 1GB
pd_buffer_device: "cuda"
nixl_backends: [UCX]

save_unfull_chunk: true
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
local_cpu: True
max_local_cpu_size: 5

remote_url: "lm://<your remote server IP>:<port>"
remote_serde: "cachegen"

retrieve_locations: ["PDBackend"]
store_location: "RemoteBackend"

enable_pd: True
transfer_channel: "nixl"
pd_role: "receiver"
pd_peer_host: "localhost"
pd_peer_init_port: 7300
pd_peer_alloc_port: 7400
pd_buffer_size: 2147483648 # 2GB
pd_buffer_device: "cuda"
nixl_backends: [UCX]

save_decode_cache: true
save_unfull_chunk: true
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
local_cpu: True
max_local_cpu_size: 5

remote_url: "lm://<your remote server IP>:<port>"
remote_serde: "cachegen"

retrieve_locations: ["PDBackend"]
store_location: "RemoteBackend"

enable_pd: True
transfer_channel: "nixl"
pd_role: "receiver"
pd_peer_host: "localhost"
pd_peer_init_port: 7301
pd_peer_alloc_port: 7401
pd_buffer_size: 2147483648 # 2GB
pd_buffer_device: "cuda"
nixl_backends: [UCX]

save_decode_cache: true
save_unfull_chunk: true
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
local_cpu: True
max_local_cpu_size: 5

remote_url: "lm://localhost:6800"
remote_serde: "cachegen"

retrieve_locations: ["LocalCPUBackend", "RemoteBackend"]

enable_pd: True
transfer_channel: "nixl"
pd_role: "sender"
pd_proxy_host: "localhost"
pd_proxy_port: 7500
pd_buffer_size: 1073741824 # 1GB
pd_buffer_device: "cuda"
nixl_backends: [UCX]

save_unfull_chunk: true
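The receiver-side constraint these decoder configs must satisfy (added to `_validate_config` in the `lmcache/v1/config.py` diff in this commit) can be exercised against a transcription of the decoder example above. The `validate` helper here is a sketch mirroring the asserts, not the real LMCache class.

```python
# Transcription of the decoder example config above (YAML -> dict),
# plus a sketch of the receiver-side validation this commit adds to
# _validate_config. validate() is illustrative, not LMCache's class.

decoder_cfg = {
    "enable_pd": True,
    "pd_role": "receiver",
    "retrieve_locations": ["PDBackend"],
    "store_location": "RemoteBackend",
    "save_decode_cache": True,
}

def validate(cfg):
    if cfg.get("enable_pd") and cfg.get("pd_role") == "receiver":
        # PDBackend is one-way (prefiller -> decoder), so a receiver
        # can retrieve from it but never store into it.
        assert cfg.get("store_location") != "PDBackend", \
            "store_location cannot be PDBackend for receiver"
        assert cfg.get("retrieve_locations") in (None, ["PDBackend"]), \
            'pd receiver must retrieve from ["PDBackend"] (or leave it unset)'
    return True

validate(decoder_cfg)  # passes: store to RemoteBackend, retrieve via PDBackend
```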

lmcache/v1/cache_engine.py

Lines changed: 25 additions & 5 deletions
@@ -180,6 +180,11 @@ def __init__(
         # at decoder.
         self.remove_after_retrieve = config.enable_pd and config.pd_role == "receiver"

+        # asymmetric store/retrieve location can be specified
+        # this is typically used (but not limited) in PD system
+        self.store_location = config.store_location
+        self.retrieve_locations = config.retrieve_locations
+
         self.num_layers = metadata.kv_shape[0]
         self.fmt = None
         if self.use_layerwise:
@@ -532,7 +537,10 @@ def store(
         # TODO: we implicitly rely on batched_put to call ref_count_down
         # this management should be done in a cleaner way
         self.storage_manager.batched_put(
-            keys, memory_objs, transfer_spec=transfer_spec
+            keys,
+            memory_objs,
+            transfer_spec=transfer_spec,
+            location=self.store_location,
         )

         self.stats_monitor.on_store_finished(
@@ -640,7 +648,9 @@ def store_layer(

             keys_multi_layer = key.split_layers(self.num_layers)
             # Only check the first layer
-            if self.storage_manager.contains(keys_multi_layer[0]):
+            if self.storage_manager.contains(
+                keys_multi_layer[0], self.retrieve_locations
+            ):
                 continue

             # Allocate the memory object
@@ -715,7 +725,9 @@ def store_layer(
         for layer_id in range(self.num_layers):
             yield
             next(mem_obj_generator)
-            self.storage_manager.batched_put(keys[layer_id], memory_objs[layer_id])
+            self.storage_manager.batched_put(
+                keys[layer_id], memory_objs[layer_id], location=self.store_location
+            )

         tot_time = time.perf_counter() - t_start
         logger.info(
@@ -848,7 +860,7 @@ def retrieve(
         for key, memory_obj, _, _ in reordered_chunks:
             if self.remove_after_retrieve and not self._is_passive():
                 assert self.storage_manager is not None
-                self.storage_manager.remove(key)
+                self.storage_manager.remove(key, self.retrieve_locations)
             if not self.async_loading:
                 memory_obj.ref_count_down()

@@ -956,7 +968,9 @@ def retrieve_layer(
         keys_multi_layer = key.split_layers(self.num_layers)

         # NOTE: Only check the first layer
-        if current_location := self.storage_manager.contains(keys_multi_layer[0]):
+        if current_location := self.storage_manager.contains(
+            keys_multi_layer[0], self.retrieve_locations
+        ):
             if location is None:
                 location = current_location
             else:
@@ -1082,6 +1096,9 @@ def lookup(
         assert hashes is not None
         lookup_stats = self.stats_monitor.on_lookup_request(sum(offsets))

+        if search_range is None:
+            search_range = self.retrieve_locations
+
         res = 0
         try:
             chunk_info_iterator = self.token_database.process_tokens(
@@ -1243,6 +1260,9 @@ def async_lookup_and_prefetch(
         keys: list[CacheEngineKey] = []
         cum_chunk_lengths = [0]

+        if search_range is None:
+            search_range = self.retrieve_locations
+
         # TODO(Jiayi): make token database able to return list.
         for start, end, key in self.token_database.process_tokens(
             tokens=tokens,
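The cache-engine changes above lean on a storage manager whose `contains` accepts a search range and returns the hit location (used with the walrus operator in `retrieve_layer`). A minimal stand-in, with class and method shapes assumed rather than taken from LMCache:

```python
# Sketch of the storage-manager behavior the cache_engine diff relies
# on: contains() restricted to a search range, returning the (truthy)
# name of the backend that holds the key. Illustrative only.

class StorageManager:
    def __init__(self, backends):
        self.backends = backends  # name -> set of cached keys

    def contains(self, key, search_range=None):
        # search_range=None -> check all backends, mirroring the
        # "Default: null (search all active backends)" semantics.
        names = search_range or list(self.backends)
        for name in names:
            if key in self.backends[name]:
                return name
        return None

sm = StorageManager({"LocalCPUBackend": set(), "RemoteBackend": {"chunk-0"}})

# Mirrors the diff's pattern:
#   if current_location := self.storage_manager.contains(key, self.retrieve_locations):
if current_location := sm.contains("chunk-0", ["LocalCPUBackend", "RemoteBackend"]):
    hit_from = current_location  # "RemoteBackend"

# Restricting retrieve_locations hides backends outside the range.
assert sm.contains("chunk-0", ["LocalCPUBackend"]) is None
```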

lmcache/v1/config.py

Lines changed: 15 additions & 5 deletions
@@ -120,6 +120,8 @@
     },
     "blend_min_tokens": {"type": int, "default": 256, "env_converter": int},
     "blend_special_str": {"type": str, "default": " # # ", "env_converter": str},
+    "retrieve_locations": {"type": Optional[list[str]], "default": None},
+    "store_location": {"type": Optional[str], "default": None},
     # P2P configurations
     "enable_p2p": {
         "type": bool,
@@ -544,11 +546,6 @@ def _validate_config(self):
             assert self.pd_role is not None
             assert self.pd_buffer_size is not None
             assert self.pd_buffer_device is not None
-
-            assert self.remote_url is None, "PD only supports remote_url=None"
-            assert self.save_decode_cache is False, (
-                "PD only supports save_decode_cache=False"
-            )
             assert self.enable_p2p is False, "PD only supports enable_p2p=False"

             # PD requires save_unfull_chunk=True for complete KV cache transfer
@@ -568,6 +565,19 @@ def _validate_config(self):
                 "including partial chunks will be transferred to decode node"
             )

+            # for receiver, PDBackend is for retrieve location
+            # can't take PDBackend as store location
+            # as PDBackend is now one way from producer to receiver only
+            if self.pd_role == "receiver":
+                assert self.store_location != "PDBackend", (
+                    "store_location cannot be PDBackend for receiver"
+                )
+                assert self.retrieve_locations in (None, ["PDBackend"]), (
+                    "for pd receiver, "
+                    'retrieve_locations are expected to be ["PDBackend"], '
+                    f"now, it is {self.retrieve_locations}"
+                )
+
         if enable_nixl_storage:
             assert self.extra_config.get("nixl_backend") is not None
             assert self.extra_config.get("nixl_pool_size") is not None
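The two schema entries added at the top of config.py follow the `{"type": ..., "default": ...}` pattern used by the surrounding options. A minimal stand-in for how such a table can drive attribute defaults (illustrative; the hypothetical `from_dict` helper is not LMCache's real loader):

```python
# Minimal sketch of a schema-driven config loader for the two new
# options. SCHEMA entries copy the diff above; from_dict() is a
# hypothetical helper, not LMCache's actual implementation.

from typing import Optional

SCHEMA = {
    "retrieve_locations": {"type": Optional[list[str]], "default": None},
    "store_location": {"type": Optional[str], "default": None},
}

def from_dict(raw):
    # Fill any unset option from its schema default.
    return {k: raw.get(k, spec["default"]) for k, spec in SCHEMA.items()}

cfg = from_dict({"store_location": "RemoteBackend"})
# cfg == {"retrieve_locations": None, "store_location": "RemoteBackend"}
```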
