Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions ucm/integration/sglang/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class UnifiedCacheStoreConfig:

@staticmethod
def load_from_config(
storage_config: "HiCacheStorageConfig", mem_pool_host: "HostKVCache"
storage_config: "HiCacheStorageConfig",
mem_pool_host: "HostKVCache",
host_pool_name: str = "",
) -> "UnifiedCacheStoreConfig":
extra = dict(getattr(storage_config, "extra_config", None) or {})
if "kv_connector_extra_config" not in extra:
Expand Down Expand Up @@ -86,9 +88,19 @@ def load_from_config(

cfg = dict(ucm_cfg)
cfg["store_pipeline"] = "Posix"
cfg["storage_backends"] = [
path for path in cfg["storage_backends"].split(":") if path
]
if host_pool_name:
for path in cfg["storage_backends"].split(":"):
if path:
os.makedirs(f"{path}/{host_pool_name}", exist_ok=True)
cfg["storage_backends"] = [
f"{path}/{host_pool_name}"
for path in cfg["storage_backends"].split(":")
if path
]
else:
cfg["storage_backends"] = [
path for path in cfg["storage_backends"].split(":") if path
]
cfg["device_id"] = get_world_group().local_rank
cfg["tensor_size"] = tensor_size
cfg["shard_size"] = block_size
Expand Down Expand Up @@ -125,11 +137,12 @@ def from_hicache(
cls,
storage_config: "HiCacheStorageConfig",
mem_pool_host: "HostKVCache",
host_pool_name: str = "",
) -> "SglangUcmConnector":
if mem_pool_host is None:
raise ValueError("mem_pool_host must be provided for UnifiedCache")
ucm_store_config = UnifiedCacheStoreConfig.load_from_config(
storage_config, mem_pool_host
storage_config, mem_pool_host, host_pool_name
)
store = UcmConnectorFactoryV1.create_connector(
ucm_store_config.name, ucm_store_config.config, ucm_store_config.module_path
Expand Down Expand Up @@ -159,6 +172,31 @@ def _get_physical_key(self, logical_key: str) -> str:
def _get_physical_keys(self, logical_keys: List[str]) -> List[str]:
return [self._get_physical_key(key) for key in logical_keys]

def _generate_ptrs(self, indices):
assert len(indices) % self.page_size == 0
ptr_list = []
k_buffer_data_ptr = self.mem_pool_host.k_buffer.data_ptr()
v_buffer_data_ptr = self.mem_pool_host.v_buffer.data_ptr()
indices = indices.tolist()
for index in range(0, len(indices), self.page_size):
k_ptr = (
k_buffer_data_ptr
+ indices[index]
* self.mem_pool_host.layer_num
* self.mem_pool_host.kv_lora_rank
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
v_ptr = (
v_buffer_data_ptr
+ indices[index]
* self.mem_pool_host.layer_num
* self.mem_pool_host.qk_rope_head_dim
* self.dtype.itemsize
)
ptr_list.append(v_ptr)
return ptr_list

def _generate_task(
self,
encoded_keys: List[bytes],
Expand All @@ -168,9 +206,12 @@ def _generate_task(
return [], [], []

shard_index_list = [0] * len(encoded_keys)
ptr_list, _ = self.mem_pool_host.get_page_buffer_meta(host_indices)
if self.mem_pool_host.layout != "page_first_kv_split":
ptr_list, _ = self.mem_pool_host.get_page_buffer_meta(host_indices)
else:
ptr_list = self._generate_ptrs(host_indices)

if not self.is_mla:
if not self.is_mla or self.mem_pool_host.layout == "page_first_kv_split":
ptr_list = [list(p) for p in zip(ptr_list[::2], ptr_list[1::2])]
else:
ptr_list = [[p] for p in ptr_list]
Expand Down
96 changes: 93 additions & 3 deletions ucm/integration/sglang/unifiedcache_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
HiCacheStorage,
HiCacheStorageConfig,
HiCacheStorageExtraInfo,
PoolHitPolicy,
PoolName,
PoolTransfer,
PoolTransferResult,
)
from sglang.srt.mem_cache.memory_pool_host import HostKVCache

Expand All @@ -29,6 +33,8 @@ def __init__(
self.connector: Optional[SglangUcmConnector] = None
self.store = None
self.mem_pool_host: Optional[HostKVCache] = None
self.host_connectors = {}
self.registered_pools = {}

if isinstance(context, HostKVCache):
self.register_mem_pool_host(context)
Expand All @@ -43,10 +49,17 @@ def _ensure_initialized(self) -> SglangUcmConnector:

def register_mem_pool_host(self, mem_pool_host: HostKVCache):
super().register_mem_pool_host(mem_pool_host)
if mem_pool_host.layout != "page_first":
allowed_layouts = [
"page_first",
"page_first_direct",
"page_head",
"page_first_kv_split",
]
if mem_pool_host.layout not in allowed_layouts:
raise ValueError(
"UnifiedCacheStore currently requires --hicache-mem-layout page_first, "
f"got {mem_pool_host.layout!r}."
f"UnifiedCacheStore only supports layouts {allowed_layouts}, "
f"but you are using {mem_pool_host.layout!r}. "
"Please set --hicache-mem-layout to right layout."
)

self.mem_pool_host = mem_pool_host
Expand Down Expand Up @@ -136,3 +149,80 @@ def close(self) -> None:
def get_stats(self):
connector = self.connector
return None if connector is None else connector.get_stats()

def register_mem_host_pool_v2(self, host_pool: HostKVCache, host_pool_name):
    """Register an auxiliary host KV pool and create its dedicated connector.

    The primary KV pool (``PoolName.KV``) is skipped — it is handled by the
    main connector. The pool's memory layout is validated BEFORE the pool is
    recorded, so an unsupported layout cannot leave a half-registered pool
    behind (previously the pool was stored first and the ValueError left it
    dangling in ``self.registered_pools``).

    Args:
        host_pool: host-side KV cache pool to register.
        host_pool_name: pool identifier; also passed to the connector so the
            storage backend paths get a per-pool subdirectory.

    Raises:
        ValueError: if ``host_pool.layout`` is not one of the supported layouts.
    """
    # TODO: Check if it's necessary to skip the primary KV pool here.
    if host_pool_name == PoolName.KV:
        return
    allowed_layouts = [
        "page_first",
        "page_first_direct",
        "page_head",
        "page_first_kv_split",
    ]
    if host_pool.layout not in allowed_layouts:
        raise ValueError(
            f"UnifiedCacheStore only supports layouts {allowed_layouts}, "
            f"but you are using {host_pool.layout!r}. "
            "Please set --hicache-mem-layout to right layout."
        )
    # Validation passed — record the pool, then build its connector.
    self.registered_pools[host_pool_name] = host_pool
    connector = SglangUcmConnector.from_hicache(
        self.storage_config, host_pool, host_pool_name
    )
    self.host_connectors[host_pool_name] = connector

def batch_exists_v2(
    self,
    keys: List[str],
    pool_transfers: Optional[List[PoolTransfer]] = None,
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> PoolTransferResult:
    """Probe the main KV store and each requested pool for existing pages.

    Returns a PoolTransferResult whose first field is the number of leading
    pages present in ALL consulted pools, and whose second field maps pool
    name -> that pool's own hit boundary.
    """
    kv_pages = self._ensure_initialized().batch_exists(keys, extra_info)

    hit_count: dict = {PoolName.KV: kv_pages} if kv_pages else {}
    final_pages = kv_pages

    for transfer in pool_transfers or []:
        if final_pages == 0:
            break
        name = transfer.name
        ex = self.host_connectors[name].batch_exists(keys, extra_info)
        # BUGFIX: reset boundary for every transfer. Previously, when the
        # non-ALL_PAGES branch's loop never matched (ex > kv_pages), boundary
        # was either unbound (UnboundLocalError) or silently reused the value
        # computed for the PREVIOUS transfer.
        boundary = 0
        if transfer.hit_policy == PoolHitPolicy.ALL_PAGES:
            boundary = ex
        else:  # trailing_pages
            # NOTE(review): trailing threshold is derived from the transfer's
            # key count; semantics assumed from the diff — confirm with caller.
            trailing = max(1, len(transfer.keys) if transfer.keys else 1)
            for prefix_len in range(kv_pages, 0, -1):
                if ex <= prefix_len:
                    boundary = prefix_len if ex >= trailing else 0
                    break
        if boundary:
            hit_count[name] = boundary
            final_pages = min(final_pages, boundary)

    return PoolTransferResult(final_pages, hit_count)

def batch_get_v2(
    self,
    transfers: List[PoolTransfer],
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> dict:
    """Fan each transfer out to its pool's connector via ``batch_get_v1``.

    Returns a dict mapping pool name to that connector's batch_get_v1 result.
    ``extra_info`` is accepted for interface parity but is not consumed here.
    """
    return {
        transfer.name: self.host_connectors[transfer.name].batch_get_v1(
            transfer.keys, transfer.host_indices
        )
        for transfer in transfers
    }

def batch_set_v2(
    self,
    transfers: List[PoolTransfer],
    extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> dict:
    """Fan each transfer out to its pool's connector via ``batch_set_v1``.

    Returns a dict mapping pool name to that connector's batch_set_v1 result.
    ``extra_info`` is accepted for interface parity but is not consumed here.
    """
    return {
        transfer.name: self.host_connectors[transfer.name].batch_set_v1(
            transfer.keys, transfer.host_indices
        )
        for transfer in transfers
    }
Loading