diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index e89fc1f4b..0cd480b8d 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -6,7 +6,7 @@ import time from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple import numpy as np import torch @@ -15,11 +15,20 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.distributed.parallel_state import get_world_group from vllm.model_executor.models.utils import extract_layer_index from vllm.platforms import current_platform from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheSpec, + MambaSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from ucm.integration.vllm.device import create_device from ucm.logger import init_logger @@ -51,6 +60,31 @@ class RequestMeta: token_processed: int = 0 +@dataclass +class HMARequestMeta(RequestMeta): + """RequestMeta extended with per-group block tracking for hybrid models. + + The inherited fields (``ucm_block_ids``, ``hbm_hit_block_num``, + ``total_hit_block_num``, ``num_token_ids``, ``vllm_block_ids``, + ``token_processed``) keep their original semantics and mirror the + full-attention group exactly, so dispatch/load/save paths inherited from + :class:`UCMDirectConnector` keep working. + + The two new fields are 2D lists indexed by the original + ``kv_cache_config.kv_cache_groups`` order (i.e. ``[group_id]``): + - ``group_ucm_block_ids[gid]``: full block hashes obtained by hashing + ``request.all_token_ids`` with group ``gid``'s own block size and + chain seed. ``group_ucm_block_ids[full_attn_group_id]`` equals the + inherited ``ucm_block_ids``. + - ``group_vllm_block_ids[gid]``: per-group VLLM physical block ids; this + is initialised as an empty list per group here and populated later by + the dispatch path (still a TODO for HMA dump/load). 
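+
+    Illustrative example (hypothetical sizes, not taken from any real
+    config): with a full-attention group 0 (block_size=64) and a
+    sliding-window group 1 (block_size=16, sliding_window=64), a 192-token
+    prefix yields ``group_ucm_block_ids = [[h0, h1, h2], [g0, ..., g11]]``
+    (3 and 12 block hashes, chained from each group's own seed) and
+    ``group_vllm_block_ids`` starts as ``[[], []]``.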
+ """ + + group_ucm_block_ids: list[list[bytes]] = field(default_factory=list) + group_vllm_block_ids: list[list[int]] = field(default_factory=list) + + @dataclass class RequestDispatchMeta: load_block_ids: tuple[ @@ -61,12 +95,17 @@ class RequestDispatchMeta: class KVCacheLayout: def __init__( - self, kvcaches, use_layerwise: bool, vllm_config: "VllmConfig" + self, + kvcaches, + ucm_config: dict, + vllm_config: "VllmConfig", + kv_cache_config: "KVCacheConfig", ) -> None: # each row is a layer, each column is a tensor_size/ptr in the layer (e.g., k, v, rope, k_index) self.base_ptrs: np.ndarray # (n_layers, n_ptrs) self.tensor_size_lists: np.ndarray # (n_layers, n_tensor_sizes) - self.use_layerwise = use_layerwise + self.use_layerwise = ucm_config.get("use_layerwise", False) + self.kv_cache_config = kv_cache_config self.vllm_config = vllm_config self.pp_size = self.vllm_config.parallel_config.pipeline_parallel_size self.num_hidden_layers = getattr( @@ -78,6 +117,8 @@ def __init__( name: extract_layer_index(name) for name in kvcaches.keys() } self.first_layer_id = next(iter(self.layer_name_to_id.values())) + self.num_blocks = self.kv_cache_config.num_blocks + self.layer_name_to_kv_cache_spec = layer_name_to_kv_cache_spec(kv_cache_config) self._build_layout(kvcaches) def _build_layout(self, kvcaches): @@ -192,8 +233,15 @@ class UCMDirectConnector(KVConnectorBase_V1): load -> forward -> save """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self.use_layerwise = False self.kv_caches: dict[str, torch.Tensor] = {} self.local_rank = ( @@ -212,6 +260,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ) self.head_size = vllm_config.model_config.get_head_size() self.element_size = vllm_config.model_config.dtype.itemsize + self.use_compress = hasattr( + self._vllm_config.model_config.hf_config, "compress_ratios" + ) + self._kv_cache_config = kv_cache_config if current_platform.is_cuda_alike(): logger.info("CUDA device is available.") @@ -365,7 +417,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for i, tensor in enumerate(sample_kv_layer): logger.info(f"kv cache shape {i}: {tensor.shape}") self.kv_cache_layout = KVCacheLayout( - self.kv_caches, self.use_layerwise, self._vllm_config + self.kv_caches, + self.launch_config, + self._vllm_config, + self._kv_cache_config, ) self.block_data_size = self.kv_cache_layout.block_size self.layer_name_to_id = self.kv_cache_layout.layer_name_to_id @@ -921,6 +976,13 @@ def wait_for_save(self) -> None: if self.enable_event_sync: self.device.destroy_event_handles() + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return False, None + class UCMCPConnector(UCMLayerWiseConnector): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): @@ -1149,9 +1211,797 @@ def get_num_new_matched_tokens(self, request, num_computed_tokens): return 0, False -class UCMConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) +def layer_name_to_kv_cache_spec( + kv_cache_config: KVCacheConfig, +) -> dict[str, list[KVCacheSpec]]: + """Map each 
model layer name to its concrete KVCacheSpec. + + Handles merged group specs and UniformTypeKVCacheSpecs (per-layer + ``kv_cache_specs`` entries). + """ + out: dict[str, list[KVCacheSpec]] = defaultdict(list) + for group in kv_cache_config.kv_cache_groups: + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + by_name = spec.kv_cache_specs + for name in group.layer_names: + out[name].append(by_name[name]) + else: + for name in group.layer_names: + out[name].append(spec) + return out + + +def block_size_from_kv_cache_spec(spec: KVCacheSpec) -> int: + """Token block size used for KV scheduling / hashing for one group spec.""" + block_size = 0 + if isinstance(spec, UniformTypeKVCacheSpecs): + block_size = next(iter(spec.kv_cache_specs.values())).block_size + else: + block_size = spec.block_size + + if current_platform.device_type == "npu" and hasattr(spec, "compress_ratio"): + block_size *= spec.compress_ratio + + return block_size + + +def sliding_window_from_kv_cache_spec(spec: KVCacheSpec) -> Optional[int]: + """Return the sliding window size of a group spec, or None for full attention. + + A group is treated as full attention iff its sample spec exposes no + ``sliding_window`` attribute or that attribute is ``None``. + """ + if isinstance(spec, UniformTypeKVCacheSpecs): + sample = next(iter(spec.kv_cache_specs.values())) + return getattr(sample, "sliding_window", None) + return getattr(spec, "sliding_window", None) + + +@dataclass +class GroupInfo: + """Per-group metadata used by :class:`KVCacheGroupManager`.""" + + group_id: int + block_size: int + # None for full-attention groups, otherwise the window length in tokens. + sliding_window: Optional[int] + layer_names: tuple[str, ...] + # Independent hash chain seed per group (see ``KVCacheGroupManager``). + seed: bytes + + @property + def is_full_attention(self) -> bool: + return self.sliding_window is None + + +class KVCacheGroupManager: + """Group-aware hashing and lookup for hybrid (HMA) connectors. + + Splits ``kv_cache_config.kv_cache_groups`` into full-attention groups + (one or more) and sliding-window groups, derives a per-group hash chain + seed, and exposes a two-stage lookup that: + + 1. For every full-attention group, hashes ``request.all_token_ids`` with + that group's block size and runs ``store.lookup_on_prefix`` on the + blocks beyond its own ``hbm_hit_block_num``. The candidate hits (in + tokens) are min'd across full-attn groups and rounded down to + ``lcm_block_size``. + 2. For each sliding-window group, re-hashes the same prefix with that + group's own block size and verifies the last + ``max(1, sliding_window // block_size)`` blocks all exist via + ``store.lookup`` (when ``block_size > sliding_window`` — e.g. on + Ascend — the last single block already covers the SW). If any + sliding-window group fails this check, the whole external hit is + downgraded to zero. + """ + + def __init__( + self, + kv_cache_config: "KVCacheConfig", + request_hasher: "RequestHasher", + base_seed: bytes, + ) -> None: + self.request_hasher = request_hasher + # Indexed by original group_id; positions match + # ``kv_cache_config.kv_cache_groups``. + self.groups_by_id: list[GroupInfo] = [] + # All groups whose spec has no sliding_window. Order follows group_id. 
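+        # E.g. (hypothetical layout): group 0 = FullAttentionSpec with
+        # block_size=64 and group 1 = SlidingWindowSpec with block_size=16,
+        # sliding_window=64 would give full_attn_groups=[group 0],
+        # sliding_window_groups=[group 1] and lcm_block_size=64 below.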
+ self.full_attn_groups: list[GroupInfo] = [] + self.sliding_window_groups: list[GroupInfo] = [] + + for group_id, group in enumerate(kv_cache_config.kv_cache_groups): + spec = group.kv_cache_spec + block_size = block_size_from_kv_cache_spec(spec) + sliding_window = sliding_window_from_kv_cache_spec(spec) + # Mix group_id into the hash chain seed so two groups with the + # same block_size do not collide in the underlying store. + seed = request_hasher((b"UCM_GROUP_SEED", base_seed, group_id)) + info = GroupInfo( + group_id=group_id, + block_size=block_size, + sliding_window=sliding_window, + layer_names=tuple(group.layer_names), + seed=seed, + ) + self.groups_by_id.append(info) + if info.is_full_attention: + self.full_attn_groups.append(info) + else: + self.sliding_window_groups.append(info) + + assert len(self.full_attn_groups) >= 1, ( + "UCMHMAConnector expects at least one full-attention group in " + "kv_cache_config.kv_cache_groups." + ) + + # Resume points must be aligned to the LCM of every group's + # block_size so that per-group block accounting (including each + # full-attn group's lookup result and every SW group's tail slice) + # lands on a clean block boundary. + all_block_sizes = [g.block_size for g in self.groups_by_id] + self.lcm_block_size: int = math.lcm(*all_block_sizes) + + for g in self.groups_by_id: + assert self.lcm_block_size % g.block_size == 0, ( + f"group {g.group_id} block_size={g.block_size} does not " + f"divide LCM={self.lcm_block_size}" + ) + for sw in self.sliding_window_groups: + # The dump path stores only ``[B - sliding_window, B)`` for each + # LCM boundary B. Requiring ``sliding_window <= lcm_block_size`` + # guarantees consecutive boundaries' tails do not overlap, so the + # incremental dump can append each boundary's tail without + # cross-boundary deduplication. + assert sw.sliding_window <= self.lcm_block_size, ( + f"sliding window group {sw.group_id} sliding_window=" + f"{sw.sliding_window} > lcm_block_size=" + f"{self.lcm_block_size}; not supported." + ) + # On some backends (e.g. Ascend) ``block_size`` can exceed + # ``sliding_window``; in that case a single block already holds + # more than a window of tokens, so we will treat the last block + # as the SW tail and skip the divisibility check. + if sw.block_size >= sw.sliding_window: + continue + if sw.sliding_window % sw.block_size != 0: + raise ValueError( + f"Sliding window group {sw.group_id} sliding_window=" + f"{sw.sliding_window} is not a multiple of block_size=" + f"{sw.block_size}." 
+ ) + + logger.info( + "KVCacheGroupManager initialized: " + f"lcm_block_size={self.lcm_block_size}, " + f"full_attn_groups=" + f"{[(g.group_id, g.block_size) for g in self.full_attn_groups]}, " + f"sliding_window_groups=" + f"{[(g.group_id, g.block_size, g.sliding_window) for g in self.sliding_window_groups]}" + ) + + @property + def num_groups(self) -> int: + return len(self.groups_by_id) + + def compute_block_hashes( + self, group: GroupInfo, token_ids: list[int] + ) -> list[bytes]: + """Hash ``token_ids`` into per-block ids using ``group``'s chain seed.""" + ret: list[bytes] = [] + parent = group.seed + block_size = group.block_size + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + if len(block_token_ids) < block_size: + break + hash_value = self.request_hasher((parent, tuple(block_token_ids))) + parent = hash_value + ret.append(hash_value) + return ret + + def compute_all_group_block_ids(self, token_ids: list[int]) -> list[list[bytes]]: + """Compute full block hashes for every group, indexed by group_id. + + ``ret[gid]`` covers all aligned blocks of ``token_ids`` using group + ``gid``'s ``block_size`` and chain seed. The trailing partial block + (if any) is dropped, matching :meth:`compute_block_hashes`. + """ + return [self.compute_block_hashes(g, token_ids) for g in self.groups_by_id] + + def lookup_external_hit_tokens( + self, + num_computed_tokens: int, + store: "UcmKVStoreBaseV1", + group_block_ids: list[list[bytes]], + ) -> tuple[int, int]: + """Two-stage HMA lookup using precomputed per-group hashes. + + ``group_block_ids`` must have one entry per group, indexed by the + original ``group_id`` (see :meth:`compute_all_group_block_ids`). + + Stage 1 — every full-attention group runs ``lookup_on_prefix`` + beyond its own ``hbm_hit_block_num``; the candidate hits are taken + as a min and rounded down to ``lcm_block_size`` so the final + external hit is consistent across all full-attn groups and aligns + to the kv-cache page granularity expected by the scheduler. + + Stage 2 — every sliding-window group must have the last + ``sliding_window // block_size`` blocks before ``total_hit_tokens`` + present in the store; if any group fails, the whole external hit + is downgraded to zero. + + Returns: + Tuple of + - ``external_hit_tokens``: tokens hit beyond ``num_computed_tokens``, + aligned to ``lcm_block_size``. ``0`` if any check fails. + - ``external_hit_lcm_blocks``: ``external_hit_tokens // + lcm_block_size`` (also ``0`` on downgrade). + """ + assert len(group_block_ids) == self.num_groups, ( + f"group_block_ids length {len(group_block_ids)} does not match " + f"num_groups {self.num_groups}" + ) + assert num_computed_tokens % self.lcm_block_size == 0, ( + f"num_computed_tokens={num_computed_tokens} is not aligned to " + f"lcm_block_size={self.lcm_block_size}" + ) + + # Stage 1: each full-attn group contributes a candidate hit count. + candidates: list[int] = [] + for fa in self.full_attn_groups: + fa_block_ids = group_block_ids[fa.group_id] + fa_hbm_blocks = num_computed_tokens // fa.block_size + fa_external = fa_block_ids[fa_hbm_blocks:] + if not fa_external: + candidates.append(0) + continue + try: + fa_hit_blocks = store.lookup_on_prefix(fa_external) + 1 + except Exception as e: + logger.error( + f"full-attn group {fa.group_id} lookup error. 
" + f"{type(e).__name__}: {e}" + ) + candidates.append(0) + continue + candidates.append(max(fa_hit_blocks, 0) * fa.block_size) + + # Resume boundary must be a multiple of lcm_block_size so every + # group's tail/dispatch slicing lands on a real block boundary. + min_external_hit_tokens = min(candidates) + external_hit_tokens = ( + min_external_hit_tokens // self.lcm_block_size + ) * self.lcm_block_size + if external_hit_tokens <= 0: + return 0, 0 + + # Stage 2: every SW group's tail window must be in the store. + total_hit_tokens = num_computed_tokens + external_hit_tokens + for sw in self.sliding_window_groups: + # When ``block_size > sliding_window`` (e.g. on Ascend) a single + # block already covers more than a window of tokens, so we use + # the last block as the SW tail (tail_count = 1). Otherwise + # ``sliding_window`` is a multiple of ``block_size`` (validated + # in __init__) and we take exactly ``sliding_window/block_size`` + # blocks. + tail_count = max(1, sw.sliding_window // sw.block_size) + min_required_tokens = tail_count * sw.block_size + if total_hit_tokens < min_required_tokens: + logger.info( + f"sliding window group {sw.group_id} tail check skipped: " + f"total_hit_tokens={total_hit_tokens} < " + f"min_required={min_required_tokens} " + f"(sliding_window={sw.sliding_window}, " + f"block_size={sw.block_size}), downgrade to 0." + ) + return 0, 0 + + sw_block_ids = group_block_ids[sw.group_id] + # ``sw.block_size`` divides ``total_hit_tokens`` because + # ``total_hit_tokens`` is a multiple of ``lcm_block_size`` and + # ``lcm_block_size`` is divisible by every group's block_size + # (validated in __init__). + num_blocks_in_total_hit = total_hit_tokens // sw.block_size + tail_block_ids = sw_block_ids[ + num_blocks_in_total_hit - tail_count : num_blocks_in_total_hit + ] + try: + results = store.lookup(tail_block_ids) + except Exception as e: + logger.error( + f"sliding window group {sw.group_id} lookup error. " + f"{type(e).__name__}: {e}" + ) + return 0, 0 + if not all(results): + logger.info( + f"sliding window group {sw.group_id} tail miss: " + f"hits={results}, downgrade external hit to 0." 
+ ) + return 0, 0 + + return external_hit_tokens, external_hit_tokens // self.lcm_block_size + + +class HMAKVCacheLayout(KVCacheLayout): + def __init__( + self, + kvcaches, + ucm_config: dict, + vllm_config: "VllmConfig", + kv_cache_config: "KVCacheConfig", + ): + super().__init__(kvcaches, ucm_config, vllm_config, kv_cache_config) + + def _build_layout(self, kvcaches): + base_ptrs = [] + tensor_size_lists = [] + + for raw_tensor in self.kv_cache_config.kv_cache_tensors: + ptrs = [] + tensor_sizes = [] + + if raw_tensor.shared_by: + sample_layer_name = raw_tensor.shared_by[0] + kv_layer = kvcaches.get(sample_layer_name) + if kv_layer is None: + logger.warning( + f"kv_layer {sample_layer_name} not found in kvcaches" + ) + continue + kv_cache_spec = self.layer_name_to_kv_cache_spec[sample_layer_name][0] + if isinstance(kv_layer, torch.Tensor): + ptrs.append(kv_layer.data_ptr()) + tensor_sizes.append(kv_cache_spec.page_size_bytes) + elif isinstance(kv_layer, (tuple, list)): + ptrs.append(kv_layer[0].data_ptr()) + tensor_sizes.append(kv_cache_spec.page_size_bytes) + else: + logger.warning(f"unsupported kv_layer type: {type(kv_layer)}") + + if not ptrs and not tensor_sizes: + continue + + base_ptrs.append(ptrs) + tensor_size_lists.append(tensor_sizes) + + self.base_ptrs = np.asarray(base_ptrs, dtype=np.uint64) + self.tensor_size_lists = np.asarray(tensor_size_lists, dtype=np.uint64) + + logger.info( + f"base_ptrs: {self.base_ptrs.shape}, tensor_size_lists: {self.tensor_size_lists.shape}" + ) + + +class AscendDSV4Layout(HMAKVCacheLayout): + def __init__( + self, + kvcaches, + ucm_config: dict, + vllm_config: "VllmConfig", + kv_cache_config: "KVCacheConfig", + ): + super().__init__(kvcaches, ucm_config, vllm_config, kv_cache_config) + self.indexer_scale_size_bytes = 0 + for _, layer_specs in self.layer_name_to_kv_cache_spec.items(): + for spec in layer_specs: + if hasattr(spec, "indexer_scale_size_bytes"): + self.indexer_scale_size_bytes = spec.indexer_scale_size_bytes + break + + def _build_layout(self, kvcaches): + self.indexer_scale_size_bytes = 0 + for _, layer_specs in self.layer_name_to_kv_cache_spec.items(): + for spec in layer_specs: + if hasattr(spec, "indexer_scale_size_bytes"): + self.indexer_scale_size_bytes = spec.indexer_scale_size_bytes + break + + base_ptrs = [] + tensor_size_lists = [] + + for raw_tensor in self.kv_cache_config.kv_cache_tensors: + ptrs = [] + tensor_sizes = [] + kv_size = raw_tensor.size - self.indexer_scale_size_bytes * self.num_blocks + + if raw_tensor.shared_by: + sample_layer_name = raw_tensor.shared_by[0] + kv_layer = kvcaches.get(sample_layer_name) + if kv_layer is None: + logger.warning( + f"kv_layer {sample_layer_name} not found in kvcaches" + ) + continue + kv_cache_specs = self.layer_name_to_kv_cache_spec[sample_layer_name] + if isinstance(kv_layer, (tuple, list)): + ptrs.append(kv_layer[0].data_ptr()) + tensor_sizes.append( + kv_cache_specs[0].page_size_bytes + - self.indexer_scale_size_bytes + ) + ptrs.append(kv_layer[0].data_ptr() + kv_size) + tensor_sizes.append(self.indexer_scale_size_bytes) + else: + logger.warning(f"unsupported kv_layer type: {type(kv_layer)}") + + if not ptrs and not tensor_sizes: + continue + + base_ptrs.append(ptrs) + tensor_size_lists.append(tensor_sizes) + + self.base_ptrs = np.asarray(base_ptrs, dtype=np.uint64) + self.tensor_size_lists = np.asarray(tensor_size_lists, dtype=np.uint64) + + logger.info( + f"base_ptrs: {self.base_ptrs.shape}, tensor_size_lists: {self.tensor_size_lists.shape}" + ) + + +class 
UCMHMAConnector(UCMDirectConnector, SupportsHMA): + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) + # group manager only lives on the scheduler side, where ``self._seed`` + # and ``self.request_hasher`` are populated by the parent ctor. + self.group_manager: Optional[KVCacheGroupManager] = None + if role == KVConnectorRole.SCHEDULER: + self.group_manager = KVCacheGroupManager( + kv_cache_config=kv_cache_config, + request_hasher=self.request_hasher, + base_seed=self._seed, + ) + lcm_block_size = self.group_manager.lcm_block_size + # Override the inherited ``block_size`` (which comes from + # ``cache_config.block_size``) so prefix accounting in this class + # is consistent with every group's block boundaries — vLLM's + # hybrid scheduler aligns ``num_computed_tokens`` to the LCM of + # all groups' block_size, and so do we. + self.block_size = lcm_block_size + self.hash_block_size = lcm_block_size + + logger.info( + f"UCMHMAConnector initialized with use_layerwise={self.use_layerwise}" + ) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + self.kv_caches = kv_caches + if current_platform.device_type == "npu" and self.use_compress: + self.kv_cache_layout = AscendDSV4Layout( + self.kv_caches, + self.launch_config, + self._vllm_config, + self._kv_cache_config, + ) + else: + self.kv_cache_layout = HMAKVCacheLayout( + self.kv_caches, + self.launch_config, + self._vllm_config, + self._kv_cache_config, + ) + self.store = self._create_store(self.kv_cache_layout) + self.block_data_size = self.kv_cache_layout.block_size + self.device = create_device() + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + assert self.group_manager is not None, ( + "get_num_new_matched_tokens must be called on the scheduler-side " + "connector, where the group manager is initialized." + ) + + lcm_block_size = self.group_manager.lcm_block_size + assert num_computed_tokens % lcm_block_size == 0, ( + f"num_computed_tokens={num_computed_tokens} is not aligned to " + f"lcm_block_size={lcm_block_size}" + ) + # ``hbm_hit_block_num`` and ``total_hit_block_num`` are tracked in + # LCM-block units in HMA mode; per-group block ids/counts are derived + # from these via each group's own block_size when needed. + hbm_hit_block_num = num_computed_tokens // lcm_block_size + + # Skip persistence if token count is below the threshold. + if self.persist_token_threshold > request.num_tokens: + logger.info_once( + f"Skip persistence: req {request.request_id}, " + f"input tokens ({request.num_tokens}) < threshold " + f"({self.persist_token_threshold})." + ) + return 0, False + + # Hash once per group so dump path can later reuse the same block ids. + group_ucm_block_ids = self.group_manager.compute_all_group_block_ids( + request.all_token_ids + ) + # Legacy ``ucm_block_ids`` mirrors the first full-attn group (by + # group_id order) for callers that still consume the flat list. 
+ primary_full_attn = self.group_manager.full_attn_groups[0] + primary_block_ids = group_ucm_block_ids[primary_full_attn.group_id] + + external_hit_tokens, external_hit_lcm_blocks = ( + self.group_manager.lookup_external_hit_tokens( + num_computed_tokens, self.store, group_ucm_block_ids + ) + ) + + if ( + self.enable_record_traces + and request.request_id not in self.requests_meta + and len(primary_block_ids) > 0 + ): + hex_block_ids = [b.hex() for b in primary_block_ids] + logger.info_once( + f"timestamp: {time.perf_counter()}, " + f"input_length: {request.num_tokens}, " + f"output_length: {request.max_tokens}, " + f"ucm_block_ids: {hex_block_ids}" + ) + + total_hit_block_num = hbm_hit_block_num + external_hit_lcm_blocks + + logger.info_once( + f"request_id: {request.request_id}, " + f"total_lcm_blocks: {request.num_tokens // lcm_block_size}, " + f"hit hbm: {hbm_hit_block_num}, " + f"hit external: {external_hit_lcm_blocks}, " + f"total_tokens: {len(request.all_token_ids)}" + ) + if self.metrics_config and len(primary_block_ids) > 0: + ucmmetrics.update_stats( + { + "interval_lookup_hit_rates": external_hit_lcm_blocks + * lcm_block_size + / (len(primary_block_ids) * primary_full_attn.block_size) + }, + ) + + # When all the tokens are cached in ssd or hbm, we need to recompute + # the last token. This branch will be removed once vLLM scheduler + # provides a better solution in the future. + num_total_hit_tokens = total_hit_block_num * lcm_block_size + if num_total_hit_tokens == request.num_tokens and external_hit_tokens > 0: + external_hit_tokens -= 1 + + self.requests_meta[request.request_id] = HMARequestMeta( + ucm_block_ids=primary_block_ids, + hbm_hit_block_num=hbm_hit_block_num, + total_hit_block_num=total_hit_block_num, + num_token_ids=len(request.all_token_ids), + token_processed=num_total_hit_tokens, + group_ucm_block_ids=group_ucm_block_ids, + group_vllm_block_ids=[[] for _ in range(self.group_manager.num_groups)], + ) + + return external_hit_tokens, False + + def _generate_hma_dispatch_meta( + self, + req_meta: "HMARequestMeta", + new_tokens: int, + new_vllm_block_ids_per_group: tuple[list[int], ...], + need_load: bool = True, + ) -> RequestDispatchMeta: + """Build a flat (ucm, vllm) block id pair list across all groups. + + The output ``RequestDispatchMeta`` keeps the same shape as the + non-HMA path (``tuple[list[bytes], list[int]]``) so that + ``start_load_kv`` / ``wait_for_save`` and the underlying store APIs + do not need to know about groups. Per-group slices are concatenated + in ascending ``group_id`` order, with ``ucm_block_ids[k]`` and + ``vllm_block_ids[k]`` always referring to the same block. + + Layout per group within ``[token_processed, token_processed + new_tokens)``: + - **load** (only when ``external_hit_blocks > 0`` and ``need_load``): + - full-attn group: tokens ``[hbm_hit_tokens, total_hit_tokens)`` + - sliding-window group: the blocks covering tokens + ``[total_hit_tokens - sliding_window, total_hit_tokens)``; + ``start_blk`` is floor-rounded so when ``block_size > + sliding_window`` (e.g. on Ascend) we naturally load the single + last block. The SW window is reloaded every resume because + older blocks are evicted by the SW manager. + - **dump** of ``[token_processed, token_processed + new_tokens)``: + - full-attn group: every newly-completed full block (the + ``lookup_on_prefix`` chain needs every prefix block to be + present). 
+ - sliding-window group: only the last + ``max(1, sliding_window/block_size)`` blocks before each LCM + boundary reached in this range. Lookup always resumes at LCM + boundaries and stage-2 SW check only inspects those tails, so + blocks between tails would be dead weight in the store. + """ + assert self.group_manager is not None + groups_by_id = self.group_manager.groups_by_id + num_groups = self.group_manager.num_groups + lcm_block_size = self.group_manager.lcm_block_size + + assert len(new_vllm_block_ids_per_group) == num_groups, ( + f"new_vllm_block_ids_per_group length " + f"{len(new_vllm_block_ids_per_group)} does not match " + f"num_groups {num_groups}" + ) + for gid in range(num_groups): + req_meta.group_vllm_block_ids[gid].extend(new_vllm_block_ids_per_group[gid]) + + load_ucm_block_ids: list[bytes] = [] + load_vllm_block_ids: list[int] = [] + dump_ucm_block_ids: list[bytes] = [] + dump_vllm_block_ids: list[int] = [] + + external_hit_lcm_blocks = ( + req_meta.total_hit_block_num - req_meta.hbm_hit_block_num + ) + hbm_hit_tokens = req_meta.hbm_hit_block_num * lcm_block_size + total_hit_tokens = req_meta.total_hit_block_num * lcm_block_size + + if need_load and external_hit_lcm_blocks > 0: + for gid, group in enumerate(groups_by_id): + if group.is_full_attention: + load_tok_start = hbm_hit_tokens + else: + load_tok_start = total_hit_tokens - group.sliding_window + load_tok_end = total_hit_tokens + start_blk = load_tok_start // group.block_size + end_blk = load_tok_end // group.block_size + if start_blk >= end_blk: + continue + load_ucm_block_ids.extend( + req_meta.group_ucm_block_ids[gid][start_blk:end_blk] + ) + load_vllm_block_ids.extend( + req_meta.group_vllm_block_ids[gid][start_blk:end_blk] + ) + + if req_meta.token_processed < req_meta.num_token_ids: + dump_tok_start = req_meta.token_processed + dump_tok_end = min( + req_meta.token_processed + new_tokens, req_meta.num_token_ids + ) + # LCM boundaries B with ``dump_tok_start < B <= dump_tok_end``. + # SW groups only need the tail at these boundaries because lookup + # always resumes at LCM boundaries (see + # ``lookup_external_hit_tokens`` stage 2). + first_lcm_b = (dump_tok_start // lcm_block_size + 1) * lcm_block_size + last_lcm_b = (dump_tok_end // lcm_block_size) * lcm_block_size + + for gid, group in enumerate(groups_by_id): + if group.is_full_attention: + # Dump every newly completed block: ``lookup_on_prefix`` + # walks the full prefix chain so any gap would truncate + # future hits. + start_blk = dump_tok_start // group.block_size + end_blk = dump_tok_end // group.block_size + if start_blk >= end_blk: + continue + dump_ucm_block_ids.extend( + req_meta.group_ucm_block_ids[gid][start_blk:end_blk] + ) + dump_vllm_block_ids.extend( + req_meta.group_vllm_block_ids[gid][start_blk:end_blk] + ) + else: + # Dump only the tail blocks at each LCM boundary reached + # in this range. Since ``sliding_window <= + # lcm_block_size`` (validated in ``KVCacheGroupManager``), + # consecutive boundaries' tails do not overlap and we can + # extend the lists without dedup. 
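+                    # Example (hypothetical numbers): lcm_block_size=64,
+                    # sliding_window=64, block_size=16 -> tail_count=4; at
+                    # boundary B=128 this dumps SW blocks [4, 8), i.e.
+                    # tokens [B - sliding_window, B) = [64, 128).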
+ if first_lcm_b > last_lcm_b: + continue + tail_count = max(1, group.sliding_window // group.block_size) + b = first_lcm_b + while b <= last_lcm_b: + end_blk = b // group.block_size + start_blk = max(0, end_blk - tail_count) + if start_blk < end_blk: + dump_ucm_block_ids.extend( + req_meta.group_ucm_block_ids[gid][start_blk:end_blk] + ) + dump_vllm_block_ids.extend( + req_meta.group_vllm_block_ids[gid][start_blk:end_blk] + ) + b += lcm_block_size + req_meta.token_processed += new_tokens + + return RequestDispatchMeta( + (load_ucm_block_ids, load_vllm_block_ids), + (dump_ucm_block_ids, dump_vllm_block_ids), + ) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + assert self.group_manager is not None + num_groups = self.group_manager.num_groups + empty_per_group: tuple[list[int], ...] = tuple([] for _ in range(num_groups)) + + requests_dispatch_meta: dict[str, RequestDispatchMeta] = {} + + for request in scheduler_output.scheduled_new_reqs: + request_id = request.req_id + req_meta = self.requests_meta.get(request_id) + if req_meta is None: + continue + assert isinstance(req_meta, HMARequestMeta) + requests_dispatch_meta[request_id] = self._generate_hma_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.block_ids, + ) + + # Same three situations as the parent: chunked prefill (dump only), + # resumed (load + dump), decode (no-op). + scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs + if not isinstance(scheduled_cached_reqs, list): + for i, request_id in enumerate(scheduled_cached_reqs.req_ids): + req_meta = self.requests_meta.get(request_id) + if req_meta is None: + continue + assert isinstance(req_meta, HMARequestMeta) + raw_new_block_ids = scheduled_cached_reqs.new_block_ids[i] + new_block_ids = ( + empty_per_group if raw_new_block_ids is None else raw_new_block_ids + ) + if hasattr(scheduled_cached_reqs, "resumed_from_preemption"): + resumed_from_preemption = ( + scheduled_cached_reqs.resumed_from_preemption[i] + ) + else: + resumed_from_preemption = ( + request_id in scheduled_cached_reqs.resumed_req_ids + ) + requests_dispatch_meta[request_id] = self._generate_hma_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + new_block_ids, + resumed_from_preemption, + ) + else: + for request in scheduled_cached_reqs: + request_id = request.req_id + req_meta = self.requests_meta.get(request_id) + if req_meta is None: + continue + assert isinstance(req_meta, HMARequestMeta) + requests_dispatch_meta[request_id] = self._generate_hma_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.new_block_ids, + request.resumed_from_preemption, + ) + + for request_id in scheduler_output.finished_req_ids: + self.requests_meta.pop(request_id, None) + + return UCMConnectorMetadata(requests_dispatch_meta) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return False, None + + +class UCMConnector(KVConnectorBase_V1, SupportsHMA): + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self.connector: KVConnectorBase_V1 ucm_config = Config(vllm_config.kv_transfer_config) self.launch_config = ucm_config.get_config() @@ -1189,6 +2039,11 @@ def __init__(self, vllm_config: "VllmConfig", role: 
KVConnectorRole): > 1 ) + use_hma = ( + self._vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False + or os.getenv("USE_MULTI_GROUPS_KV_CACHE") == "1" + ) + if use_lite: self.connector = UCMLiteConnector(vllm_config, role) elif use_ratio_rate: @@ -1197,8 +2052,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.connector = UCMCPConnector(vllm_config, role) elif use_layerwise: self.connector = UCMLayerWiseConnector(vllm_config, role) + elif use_hma: + self.connector = UCMHMAConnector(vllm_config, role, kv_cache_config) else: - self.connector = UCMDirectConnector(vllm_config, role) + self.connector = UCMDirectConnector(vllm_config, role, kv_cache_config) def get_num_new_matched_tokens( self, @@ -1350,3 +2207,10 @@ def get_block_ids_with_load_errors(self) -> set[int]: Empty set if no load errors occurred. """ return self.connector.get_block_ids_with_load_errors() + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return self.connector.request_finished_all_groups(request, block_ids)
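
A minimal standalone sketch of the arithmetic behind
KVCacheGroupManager.lookup_external_hit_tokens, assuming a hypothetical hybrid
layout (full-attention block_size=64; sliding-window block_size=16,
sliding_window=64, hence lcm_block_size=64) and stubbing the store results
with plain values instead of a real UcmKVStoreBaseV1; variable names here are
illustrative only:

import math

fa_block_size = 64                          # full-attention group
sw_block_size, sliding_window = 16, 64      # sliding-window group
lcm_block_size = math.lcm(fa_block_size, sw_block_size)                     # 64

num_computed_tokens = 0                     # HBM prefix hit, LCM-aligned

# Stage 1: every full-attention group reports how many prefix blocks the store
# holds beyond the HBM hit; the minimum (in tokens) is rounded down to the LCM.
fa_hit_blocks = 3                           # pretend lookup_on_prefix hit 3 blocks
candidates = [fa_hit_blocks * fa_block_size]                                # [192]
external_hit_tokens = (min(candidates) // lcm_block_size) * lcm_block_size  # 192

# Stage 2: the sliding-window group must hold the tail_count blocks covering
# the last sliding_window tokens before the resume point, or the hit drops to 0.
total_hit_tokens = num_computed_tokens + external_hit_tokens                # 192
tail_count = max(1, sliding_window // sw_block_size)                        # 4
end_blk = total_hit_tokens // sw_block_size                                 # 12
tail_blocks = list(range(end_blk - tail_count, end_blk))                    # [8, 9, 10, 11]
sw_tail_in_store = True                     # pretend store.lookup(tail hashes) all hit
if total_hit_tokens < tail_count * sw_block_size or not sw_tail_in_store:
    external_hit_tokens = 0

print(external_hit_tokens, external_hit_tokens // lcm_block_size)           # 192 3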