diff --git a/ucm/integration/sglang/ucm_connector.py b/ucm/integration/sglang/ucm_connector.py index 60c2ad9de..139ea5195 100644 --- a/ucm/integration/sglang/ucm_connector.py +++ b/ucm/integration/sglang/ucm_connector.py @@ -48,7 +48,9 @@ class UnifiedCacheStoreConfig: @staticmethod def load_from_config( - storage_config: "HiCacheStorageConfig", mem_pool_host: "HostKVCache" + storage_config: "HiCacheStorageConfig", + mem_pool_host: "HostKVCache", + host_pool_name: str = "", ) -> "UnifiedCacheStoreConfig": extra = dict(getattr(storage_config, "extra_config", None) or {}) if "kv_connector_extra_config" not in extra: @@ -86,9 +88,19 @@ def load_from_config( cfg = dict(ucm_cfg) cfg["store_pipeline"] = "Posix" - cfg["storage_backends"] = [ - path for path in cfg["storage_backends"].split(":") if path - ] + if host_pool_name: + for path in cfg["storage_backends"].split(":"): + if path: + os.makedirs(f"{path}/{host_pool_name}", exist_ok=True) + cfg["storage_backends"] = [ + f"{path}/{host_pool_name}" + for path in cfg["storage_backends"].split(":") + if path + ] + else: + cfg["storage_backends"] = [ + path for path in cfg["storage_backends"].split(":") if path + ] cfg["device_id"] = get_world_group().local_rank cfg["tensor_size"] = tensor_size cfg["shard_size"] = block_size @@ -125,11 +137,12 @@ def from_hicache( cls, storage_config: "HiCacheStorageConfig", mem_pool_host: "HostKVCache", + host_pool_name: str = "", ) -> "SglangUcmConnector": if mem_pool_host is None: raise ValueError("mem_pool_host must be provided for UnifiedCache") ucm_store_config = UnifiedCacheStoreConfig.load_from_config( - storage_config, mem_pool_host + storage_config, mem_pool_host, host_pool_name ) store = UcmConnectorFactoryV1.create_connector( ucm_store_config.name, ucm_store_config.config, ucm_store_config.module_path @@ -159,6 +172,31 @@ def _get_physical_key(self, logical_key: str) -> str: def _get_physical_keys(self, logical_keys: List[str]) -> List[str]: return [self._get_physical_key(key) 
for key in logical_keys] + def _generate_ptrs(self, indices): + assert len(indices) % self.page_size == 0 + ptr_list = [] + k_buffer_data_ptr = self.mem_pool_host.k_buffer.data_ptr() + v_buffer_data_ptr = self.mem_pool_host.v_buffer.data_ptr() + indices = indices.tolist() + for index in range(0, len(indices), self.page_size): + k_ptr = ( + k_buffer_data_ptr + + indices[index] + * self.mem_pool_host.layer_num + * self.mem_pool_host.kv_lora_rank + * self.dtype.itemsize + ) + ptr_list.append(k_ptr) + v_ptr = ( + v_buffer_data_ptr + + indices[index] + * self.mem_pool_host.layer_num + * self.mem_pool_host.qk_rope_head_dim + * self.dtype.itemsize + ) + ptr_list.append(v_ptr) + return ptr_list + def _generate_task( self, encoded_keys: List[bytes], @@ -168,9 +206,12 @@ def _generate_task( return [], [], [] shard_index_list = [0] * len(encoded_keys) - ptr_list, _ = self.mem_pool_host.get_page_buffer_meta(host_indices) + if self.mem_pool_host.layout != "page_first_kv_split": + ptr_list, _ = self.mem_pool_host.get_page_buffer_meta(host_indices) + else: + ptr_list = self._generate_ptrs(host_indices) - if not self.is_mla: + if not self.is_mla or self.mem_pool_host.layout == "page_first_kv_split": ptr_list = [list(p) for p in zip(ptr_list[::2], ptr_list[1::2])] else: ptr_list = [[p] for p in ptr_list] diff --git a/ucm/integration/sglang/unifiedcache_store.py b/ucm/integration/sglang/unifiedcache_store.py index 259719fc7..9dc462c0c 100644 --- a/ucm/integration/sglang/unifiedcache_store.py +++ b/ucm/integration/sglang/unifiedcache_store.py @@ -6,6 +6,10 @@ HiCacheStorage, HiCacheStorageConfig, HiCacheStorageExtraInfo, + PoolHitPolicy, + PoolName, + PoolTransfer, + PoolTransferResult, ) from sglang.srt.mem_cache.memory_pool_host import HostKVCache @@ -29,6 +33,8 @@ def __init__( self.connector: Optional[SglangUcmConnector] = None self.store = None self.mem_pool_host: Optional[HostKVCache] = None + self.host_connectors = {} + self.registered_pools = {} if isinstance(context, 
HostKVCache): self.register_mem_pool_host(context) @@ -43,10 +49,17 @@ def _ensure_initialized(self) -> SglangUcmConnector: def register_mem_pool_host(self, mem_pool_host: HostKVCache): super().register_mem_pool_host(mem_pool_host) - if mem_pool_host.layout != "page_first": + allowed_layouts = [ + "page_first", + "page_first_direct", + "page_head", + "page_first_kv_split", + ] + if mem_pool_host.layout not in allowed_layouts: raise ValueError( - "UnifiedCacheStore currently requires --hicache-mem-layout page_first, " - f"got {mem_pool_host.layout!r}." + f"UnifiedCacheStore only supports layouts {allowed_layouts}, " + f"but you are using {mem_pool_host.layout!r}. " + "Please set --hicache-mem-layout to right layout." ) self.mem_pool_host = mem_pool_host @@ -136,3 +149,80 @@ def close(self) -> None: def get_stats(self): connector = self.connector return None if connector is None else connector.get_stats() + + def register_mem_host_pool_v2(self, host_pool: HostKVCache, host_pool_name): + # TODO: Check if it's necessary + if host_pool_name == PoolName.KV: + return + self.registered_pools[host_pool_name] = host_pool + allowed_layouts = [ + "page_first", + "page_first_direct", + "page_head", + "page_first_kv_split", + ] + if host_pool.layout not in allowed_layouts: + raise ValueError( + f"UnifiedCacheStore only supports layouts {allowed_layouts}, " + f"but you are using {host_pool.layout!r}. " + "Please set --hicache-mem-layout to right layout." 
+ ) + + connector = SglangUcmConnector.from_hicache( + self.storage_config, host_pool, host_pool_name + ) + self.host_connectors[host_pool_name] = connector + + def batch_exists_v2( + self, + keys: List[str], + pool_transfers: Optional[List[PoolTransfer]] = None, + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> PoolTransferResult: + kv_pages = self._ensure_initialized().batch_exists(keys, extra_info) + + hit_count: dict = {PoolName.KV: kv_pages} if kv_pages else {} + final_pages = kv_pages + + for transfer in pool_transfers or []: + if final_pages == 0: + break + name = transfer.name + ex = self.host_connectors[name].batch_exists(keys, extra_info) + if transfer.hit_policy == PoolHitPolicy.ALL_PAGES: + boundary = ex + else: # trailing_pages + trailing = max(1, len(transfer.keys) if transfer.keys else 1) + for prefix_len in range(kv_pages, 0, -1): + if ex <= prefix_len: + boundary = prefix_len if ex >= trailing else 0 + break + if boundary: + hit_count[name] = boundary + final_pages = min(final_pages, boundary) + + return PoolTransferResult(final_pages, hit_count) + + def batch_get_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> dict: + results: dict = {} + for transfer in transfers: + results[transfer.name] = self.host_connectors[transfer.name].batch_get_v1( + transfer.keys, transfer.host_indices + ) + return results + + def batch_set_v2( + self, + transfers: List[PoolTransfer], + extra_info: Optional[HiCacheStorageExtraInfo] = None, + ) -> dict: + results: dict = {} + for transfer in transfers: + results[transfer.name] = self.host_connectors[transfer.name].batch_set_v1( + transfer.keys, transfer.host_indices + ) + return results