From a9bb5488d44617100e012238834a6294cf9b4c9e Mon Sep 17 00:00:00 2001 From: qyh111 Date: Thu, 30 Apr 2026 06:41:08 +0000 Subject: [PATCH 1/2] Adapt deepseek v4 --- ucm/integration/vllm/ucm_connector.py | 680 +++++++++++++++++++++++++- 1 file changed, 670 insertions(+), 10 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index e89fc1f4b..bf7268457 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -6,7 +6,7 @@ import time from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple import numpy as np import torch @@ -15,6 +15,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.distributed.parallel_state import get_world_group from vllm.model_executor.models.utils import extract_layer_index @@ -28,6 +29,14 @@ from ucm.store.factory_v1 import UcmConnectorFactoryV1 from ucm.store.ucmstore_v1 import Task, UcmKVStoreBaseV1 from ucm.utils import Config +from vllm.v1.kv_cache_interface import ( + KVCacheConfig, + KVCacheSpec, + FullAttentionSpec, + MambaSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -51,6 +60,31 @@ class RequestMeta: token_processed: int = 0 +@dataclass +class HMARequestMeta(RequestMeta): + """RequestMeta extended with per-group block tracking for hybrid models. + + The inherited fields (``ucm_block_ids``, ``hbm_hit_block_num``, + ``total_hit_block_num``, ``num_token_ids``, ``vllm_block_ids``, + ``token_processed``) keep their original semantics and mirror the + full-attention group exactly, so dispatch/load/save paths inherited from + :class:`UCMDirectConnector` keep working. + + The two new fields are 2D lists indexed by the original + ``kv_cache_config.kv_cache_groups`` order (i.e. 
``[group_id]``): + - ``group_ucm_block_ids[gid]``: full block hashes obtained by hashing + ``request.all_token_ids`` with group ``gid``'s own block size and + chain seed. ``group_ucm_block_ids[full_attn_group_id]`` equals the + inherited ``ucm_block_ids``. + - ``group_vllm_block_ids[gid]``: per-group VLLM physical block ids; this + is initialised as an empty list per group here and populated later by + the dispatch path (still a TODO for HMA dump/load). + """ + + group_ucm_block_ids: list[list[bytes]] = field(default_factory=list) + group_vllm_block_ids: list[list[int]] = field(default_factory=list) + + @dataclass class RequestDispatchMeta: load_block_ids: tuple[ @@ -61,12 +95,13 @@ class RequestDispatchMeta: class KVCacheLayout: def __init__( - self, kvcaches, use_layerwise: bool, vllm_config: "VllmConfig" + self, kvcaches, ucm_config: dict, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig" ) -> None: # each row is a layer, each column is a tensor_size/ptr in the layer (e.g., k, v, rope, k_index) self.base_ptrs: np.ndarray # (n_layers, n_ptrs) self.tensor_size_lists: np.ndarray # (n_layers, n_tensor_sizes) - self.use_layerwise = use_layerwise + self.use_layerwise = ucm_config.get("use_layerwise", False) + self.kv_cache_config = kv_cache_config self.vllm_config = vllm_config self.pp_size = self.vllm_config.parallel_config.pipeline_parallel_size self.num_hidden_layers = getattr( @@ -192,8 +227,8 @@ class UCMDirectConnector(KVConnectorBase_V1): load -> forward -> save """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: "KVCacheConfig"): + super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) self.use_layerwise = False self.kv_caches: dict[str, torch.Tensor] = {} self.local_rank = ( @@ -212,6 +247,7 @@ def __init__(self, vllm_config: "VllmConfig", role: 
KVConnectorRole): ) self.head_size = vllm_config.model_config.get_head_size() self.element_size = vllm_config.model_config.dtype.itemsize + self._kv_cache_config = kv_cache_config if current_platform.is_cuda_alike(): logger.info("CUDA device is available.") @@ -365,7 +401,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for i, tensor in enumerate(sample_kv_layer): logger.info(f"kv cache shape {i}: {tensor.shape}") self.kv_cache_layout = KVCacheLayout( - self.kv_caches, self.use_layerwise, self._vllm_config + self.kv_caches, + self.launch_config, + self._vllm_config, + self._kv_cache_config, ) self.block_data_size = self.kv_cache_layout.block_size self.layer_name_to_id = self.kv_cache_layout.layer_name_to_id @@ -921,6 +960,12 @@ def wait_for_save(self) -> None: if self.enable_event_sync: self.device.destroy_event_handles() + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return False, None class UCMCPConnector(UCMLayerWiseConnector): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): @@ -1148,10 +1193,614 @@ def get_num_new_matched_tokens(self, request, num_computed_tokens): ) return 0, False +def layer_name_to_kv_cache_spec( + kv_cache_config: KVCacheConfig, +) -> dict[str, KVCacheSpec]: + """Map each model layer name to its concrete KVCacheSpec. -class UCMConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + Handles merged group specs and UniformTypeKVCacheSpecs (per-layer + ``kv_cache_specs`` entries). 
+ """ + out: dict[str, KVCacheSpec] = {} + for group in kv_cache_config.kv_cache_groups: + spec = group.kv_cache_spec + if isinstance(spec, UniformTypeKVCacheSpecs): + by_name = spec.kv_cache_specs + for name in group.layer_names: + out[name] = by_name[name] + else: + for name in group.layer_names: + out[name] = spec + return out + + + +def block_size_from_kv_cache_spec(spec: KVCacheSpec) -> int: + """Token block size used for KV scheduling / hashing for one group spec.""" + if isinstance(spec, UniformTypeKVCacheSpecs): + return next(iter(spec.kv_cache_specs.values())).block_size + return spec.block_size + + +def sliding_window_from_kv_cache_spec(spec: KVCacheSpec) -> Optional[int]: + """Return the sliding window size of a group spec, or None for full attention. + + A group is treated as full attention iff its sample spec exposes no + ``sliding_window`` attribute or that attribute is ``None``. + """ + if isinstance(spec, UniformTypeKVCacheSpecs): + sample = next(iter(spec.kv_cache_specs.values())) + return getattr(sample, "sliding_window", None) + return getattr(spec, "sliding_window", None) + + +@dataclass +class GroupInfo: + """Per-group metadata used by :class:`KVCacheGroupManager`.""" + + group_id: int + block_size: int + # None for full-attention groups, otherwise the window length in tokens. + sliding_window: Optional[int] + layer_names: tuple[str, ...] + # Independent hash chain seed per group (see ``KVCacheGroupManager``). + seed: bytes + + @property + def is_full_attention(self) -> bool: + return self.sliding_window is None + + +class KVCacheGroupManager: + """Group-aware hashing and lookup for hybrid (HMA) connectors. + + Splits ``kv_cache_config.kv_cache_groups`` into the full-attention group + (exactly one) and sliding-window groups, derives a per-group hash chain + seed, and exposes a two-stage lookup that: + + 1. 
Hashes ``request.all_token_ids`` with the full-attention group's block + size and runs ``store.lookup_on_prefix`` on the blocks beyond + ``hbm_hit_block_num``. + 2. For each sliding-window group, re-hashes the same prefix with that + group's own block size and verifies the last + ``sliding_window // block_size`` blocks all exist via ``store.lookup``. + If any sliding-window group fails this check, the whole external hit + is downgraded to zero. + """ + + def __init__( + self, + kv_cache_config: "KVCacheConfig", + request_hasher: "RequestHasher", + base_seed: bytes, + ) -> None: + self.request_hasher = request_hasher + # Indexed by original group_id; positions match + # ``kv_cache_config.kv_cache_groups``. + self.groups_by_id: list[GroupInfo] = [] + full_attn_groups: list[GroupInfo] = [] + self.sliding_window_groups: list[GroupInfo] = [] + + for group_id, group in enumerate(kv_cache_config.kv_cache_groups): + spec = group.kv_cache_spec + block_size = block_size_from_kv_cache_spec(spec) + sliding_window = sliding_window_from_kv_cache_spec(spec) + # Mix group_id into the hash chain seed so two groups with the + # same block_size do not collide in the underlying store. 
+ seed = request_hasher((b"UCM_GROUP_SEED", base_seed, group_id)) + info = GroupInfo( + group_id=group_id, + block_size=block_size, + sliding_window=sliding_window, + layer_names=tuple(group.layer_names), + seed=seed, + ) + self.groups_by_id.append(info) + if info.is_full_attention: + full_attn_groups.append(info) + else: + self.sliding_window_groups.append(info) + + assert len(full_attn_groups) == 1, ( + f"UCMHMAConnector expects exactly one full-attention group, got " + f"{len(full_attn_groups)}: " + f"{[(g.group_id, g.block_size) for g in full_attn_groups]}" + ) + self.full_attn_group: GroupInfo = full_attn_groups[0] + + # Sliding-window groups whose block_size does not divide the + # full-attn group's block_size cannot align tail checks with full-attn + # block boundaries; reject early to surface configuration mistakes. + full_attn_block_size = self.full_attn_group.block_size + for sw in self.sliding_window_groups: + if full_attn_block_size % sw.block_size != 0: + raise ValueError( + f"Sliding window group {sw.group_id} block_size=" + f"{sw.block_size} does not divide full-attn block_size=" + f"{full_attn_block_size}." + ) + if sw.sliding_window % sw.block_size != 0: + raise ValueError( + f"Sliding window group {sw.group_id} sliding_window=" + f"{sw.sliding_window} is not a multiple of block_size=" + f"{sw.block_size}." 
+ ) + + logger.info( + "KVCacheGroupManager initialized: " + f"full_attn_group=({self.full_attn_group.group_id}, " + f"{self.full_attn_group.block_size}), " + f"sliding_window_groups=" + f"{[(g.group_id, g.block_size, g.sliding_window) for g in self.sliding_window_groups]}" + ) + + @property + def num_groups(self) -> int: + return len(self.groups_by_id) + + def compute_block_hashes( + self, group: GroupInfo, token_ids: list[int] + ) -> list[bytes]: + """Hash ``token_ids`` into per-block ids using ``group``'s chain seed.""" + ret: list[bytes] = [] + parent = group.seed + block_size = group.block_size + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + if len(block_token_ids) < block_size: + break + hash_value = self.request_hasher((parent, tuple(block_token_ids))) + parent = hash_value + ret.append(hash_value) + return ret + + def compute_all_group_block_ids( + self, token_ids: list[int] + ) -> list[list[bytes]]: + """Compute full block hashes for every group, indexed by group_id. + + ``ret[gid]`` covers all aligned blocks of ``token_ids`` using group + ``gid``'s ``block_size`` and chain seed. The trailing partial block + (if any) is dropped, matching :meth:`compute_block_hashes`. + """ + return [ + self.compute_block_hashes(g, token_ids) for g in self.groups_by_id + ] + + def lookup_external_hit_tokens( + self, + hbm_hit_block_num: int, + store: "UcmKVStoreBaseV1", + group_block_ids: list[list[bytes]], + ) -> tuple[int, int]: + """Two-stage HMA lookup using precomputed per-group hashes. + + ``group_block_ids`` must have one entry per group, indexed by the + original ``group_id`` (see :meth:`compute_all_group_block_ids`). + + Returns: + Tuple of + - ``external_hit_tokens``: tokens hit beyond ``hbm_hit_block_num``, + aligned to the full-attn group's block size. ``0`` if any + sliding-window group fails its tail check. 
+ - ``external_hit_blocks``: number of full-attn blocks hit beyond + ``hbm_hit_block_num`` (also ``0`` on downgrade). + """ + assert len(group_block_ids) == self.num_groups, ( + f"group_block_ids length {len(group_block_ids)} does not match " + f"num_groups {self.num_groups}" + ) + + full_attn = self.full_attn_group + full_attn_block_ids = group_block_ids[full_attn.group_id] + + external_block_ids = full_attn_block_ids[hbm_hit_block_num:] + if not external_block_ids: + return 0, 0 + + try: + external_hit_blocks = store.lookup_on_prefix(external_block_ids) + 1 + except Exception as e: + logger.error( + f"full-attn group {full_attn.group_id} lookup error. " + f"{type(e).__name__}: {e}" + ) + return 0, 0 + + if external_hit_blocks <= 0: + return 0, 0 + + external_hit_tokens = external_hit_blocks * full_attn.block_size + # Resume boundary: SW layers need ``[total_hit_tokens - sliding_window, + # total_hit_tokens)`` to be present in external storage, regardless of + # whether some of those tokens are also covered by ``hbm_hit_tokens`` + # (the SW manager has dropped them anyway when resuming). + total_hit_tokens = ( + hbm_hit_block_num * full_attn.block_size + external_hit_tokens + ) + + for sw in self.sliding_window_groups: + # Q3: not enough hit to fill a full sliding window — treat as miss. + if total_hit_tokens < sw.sliding_window: + logger.info( + f"sliding window group {sw.group_id} tail check skipped: " + f"total_hit_tokens={total_hit_tokens} < " + f"sliding_window={sw.sliding_window}, downgrade to 0." + ) + return 0, 0 + + sw_block_ids = group_block_ids[sw.group_id] + # ``sw.block_size`` divides ``total_hit_tokens`` (block_size + # divides full_attn.block_size, validated in __init__), so this + # slice is exact and never crosses the group's last full block. 
+ num_blocks_in_total_hit = total_hit_tokens // sw.block_size + tail_count = sw.sliding_window // sw.block_size + tail_block_ids = sw_block_ids[ + num_blocks_in_total_hit - tail_count : num_blocks_in_total_hit + ] + try: + results = store.lookup(tail_block_ids) + except Exception as e: + logger.error( + f"sliding window group {sw.group_id} lookup error. " + f"{type(e).__name__}: {e}" + ) + return 0, 0 + if not all(results): + logger.info( + f"sliding window group {sw.group_id} tail miss: " + f"hits={results}, downgrade external hit to 0." + ) + return 0, 0 + + return external_hit_tokens, external_hit_blocks + + +class HMAKVCacheLayout(KVCacheLayout): + def __init__( + self, + kvcaches, + ucm_config: dict, + vllm_config: "VllmConfig", + kv_cache_config: "KVCacheConfig", + ): + self.layer_name_to_kv_cache_spec = layer_name_to_kv_cache_spec( + kv_cache_config + ) + super().__init__(kvcaches, ucm_config, vllm_config, kv_cache_config) + + def _build_layout(self, kvcaches): + base_ptrs = [] + tensor_size_lists = [] + + for raw_tensor in self.kv_cache_config.kv_cache_tensors: + ptrs = [] + tensor_sizes = [] + + if raw_tensor.shared_by: + sample_layer_name = raw_tensor.shared_by[0] + kv_layer = kvcaches.get(sample_layer_name) + if kv_layer is None: + logger.warning(f"kv_layer {sample_layer_name} not found in kvcaches") + continue + kv_cache_spec = self.layer_name_to_kv_cache_spec[sample_layer_name] + if isinstance(kv_layer, torch.Tensor): + ptrs.append(kv_layer.data_ptr()) + tensor_sizes.append(kv_cache_spec.page_size_bytes) + elif isinstance(kv_layer, (tuple, list)): + ptrs.append(kv_layer[0].data_ptr()) + tensor_sizes.append(kv_cache_spec.page_size_bytes) + else: + logger.warning(f"unsupported kv_layer type: {type(kv_layer)}") + + if not ptrs and not tensor_sizes: + continue + + base_ptrs.append(ptrs) + tensor_size_lists.append(tensor_sizes) + + self.base_ptrs = np.asarray(base_ptrs, dtype=np.uint64) + self.tensor_size_lists = np.asarray(tensor_size_lists, 
dtype=np.uint64) + + logger.info( + f"base_ptrs: {self.base_ptrs.shape}, tensor_size_lists: {self.tensor_size_lists.shape}" + ) + + +class UCMHMAConnector(UCMDirectConnector, SupportsHMA): + def __init__( + self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: "KVCacheConfig" + ): + super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) + # group manager only lives on the scheduler side, where ``self._seed`` + # and ``self.request_hasher`` are populated by the parent ctor. + self.group_manager: Optional[KVCacheGroupManager] = None + if role == KVConnectorRole.SCHEDULER: + self.group_manager = KVCacheGroupManager( + kv_cache_config=kv_cache_config, + request_hasher=self.request_hasher, + base_seed=self._seed, + ) + full_attn_block_size = self.group_manager.full_attn_group.block_size + # Override the inherited ``block_size`` (which comes from + # ``cache_config.block_size``) so prefix accounting in this class + # is consistent with the full-attn group. + self.block_size = full_attn_block_size + self.hash_block_size = full_attn_block_size + + logger.info(f"UCMHMAConnector initialized with use_layerwise={self.use_layerwise}") + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + self.kv_caches = kv_caches + self.kv_cache_layout = HMAKVCacheLayout( + self.kv_caches, + self.launch_config, + self._vllm_config, + self._kv_cache_config, + ) + self.store = self._create_store(self.kv_cache_layout) + self.block_data_size = self.kv_cache_layout.block_size + self.device = create_device() + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + assert self.group_manager is not None, ( + "get_num_new_matched_tokens must be called on the scheduler-side " + "connector, where the group manager is initialized." 
+ ) + + full_attn = self.group_manager.full_attn_group + full_attn_block_size = full_attn.block_size + assert num_computed_tokens % full_attn_block_size == 0, ( + f"num_computed_tokens={num_computed_tokens} is not aligned to " + f"full-attn group block_size={full_attn_block_size}" + ) + hbm_hit_block_num = num_computed_tokens // full_attn_block_size + + # Skip persistence if token count is below the threshold. + if self.persist_token_threshold > request.num_tokens: + logger.info_once( + f"Skip persistence: req {request.request_id}, " + f"input tokens ({request.num_tokens}) < threshold " + f"({self.persist_token_threshold})." + ) + return 0, False + + # Hash once per group so dump path can later reuse the same block ids. + group_ucm_block_ids = self.group_manager.compute_all_group_block_ids( + request.all_token_ids + ) + full_attn_block_ids = group_ucm_block_ids[full_attn.group_id] + + external_hit_tokens, external_hit_blocks = ( + self.group_manager.lookup_external_hit_tokens( + hbm_hit_block_num, self.store, group_ucm_block_ids + ) + ) + + if ( + self.enable_record_traces + and request.request_id not in self.requests_meta + and len(full_attn_block_ids) > 0 + ): + hex_block_ids = [b.hex() for b in full_attn_block_ids] + logger.info_once( + f"timestamp: {time.perf_counter()}, " + f"input_length: {request.num_tokens}, " + f"output_length: {request.max_tokens}, " + f"ucm_block_ids: {hex_block_ids}" + ) + + total_hit_block_num = hbm_hit_block_num + external_hit_blocks + + logger.info_once( + f"request_id: {request.request_id}, " + f"total_blocks_num: {len(full_attn_block_ids)}, " + f"hit hbm: {hbm_hit_block_num}, " + f"hit external: {external_hit_blocks}" + ) + if self.metrics_config and len(full_attn_block_ids) > 0: + ucmmetrics.update_stats( + { + "interval_lookup_hit_rates": external_hit_blocks + / len(full_attn_block_ids) + }, + ) + + # When all the tokens are cached in ssd or hbm, we need to recompute + # the last token. 
This branch will be removed once vLLM scheduler + # provides a better solution in the future. + num_total_hit_tokens = total_hit_block_num * full_attn_block_size + if num_total_hit_tokens == request.num_tokens and external_hit_tokens > 0: + external_hit_tokens -= 1 + + self.requests_meta[request.request_id] = HMARequestMeta( + ucm_block_ids=full_attn_block_ids, + hbm_hit_block_num=hbm_hit_block_num, + total_hit_block_num=total_hit_block_num, + num_token_ids=len(request.all_token_ids), + token_processed=num_total_hit_tokens, + group_ucm_block_ids=group_ucm_block_ids, + group_vllm_block_ids=[[] for _ in range(self.group_manager.num_groups)], + ) + + return external_hit_tokens, False + + def _generate_hma_dispatch_meta( + self, + req_meta: "HMARequestMeta", + new_tokens: int, + new_vllm_block_ids_per_group: tuple[list[int], ...], + need_load: bool = True, + ) -> RequestDispatchMeta: + """Build a flat (ucm, vllm) block id pair list across all groups. + + The output ``RequestDispatchMeta`` keeps the same shape as the + non-HMA path (``tuple[list[bytes], list[int]]``) so that + ``start_load_kv`` / ``wait_for_save`` and the underlying store APIs + do not need to know about groups. Per-group slices are concatenated + in ascending ``group_id`` order, with ``ucm_block_ids[k]`` and + ``vllm_block_ids[k]`` always referring to the same block. + + Layout per group within ``[token_processed, token_processed + new_tokens)``: + - **load** (only when ``external_hit_blocks > 0`` and ``need_load``): + - full-attn group: tokens ``[hbm_hit_tokens, total_hit_tokens)`` + - sliding-window group: tokens + ``[total_hit_tokens - sliding_window, total_hit_tokens)`` + (the SW window is reloaded every resume because older blocks are + evicted by the SW manager). 
+ - **dump**: every newly-completed full block of every group inside + ``[token_processed, token_processed + new_tokens)`` (Option A: SW + groups dump *all* blocks, not only the tail, so future requests + can hit at any full-attn boundary). + """ + assert self.group_manager is not None + groups_by_id = self.group_manager.groups_by_id + num_groups = self.group_manager.num_groups + full_attn_bs = self.group_manager.full_attn_group.block_size + + assert len(new_vllm_block_ids_per_group) == num_groups, ( + f"new_vllm_block_ids_per_group length " + f"{len(new_vllm_block_ids_per_group)} does not match " + f"num_groups {num_groups}" + ) + for gid in range(num_groups): + req_meta.group_vllm_block_ids[gid].extend( + new_vllm_block_ids_per_group[gid] + ) + + load_ucm_block_ids: list[bytes] = [] + load_vllm_block_ids: list[int] = [] + dump_ucm_block_ids: list[bytes] = [] + dump_vllm_block_ids: list[int] = [] + + external_hit_blocks = ( + req_meta.total_hit_block_num - req_meta.hbm_hit_block_num + ) + hbm_hit_tokens = req_meta.hbm_hit_block_num * full_attn_bs + total_hit_tokens = req_meta.total_hit_block_num * full_attn_bs + + if need_load and external_hit_blocks > 0: + for gid, group in enumerate(groups_by_id): + if group.is_full_attention: + load_tok_start = hbm_hit_tokens + else: + load_tok_start = total_hit_tokens - group.sliding_window + load_tok_end = total_hit_tokens + start_blk = load_tok_start // group.block_size + end_blk = load_tok_end // group.block_size + if start_blk >= end_blk: + continue + load_ucm_block_ids.extend( + req_meta.group_ucm_block_ids[gid][start_blk:end_blk] + ) + load_vllm_block_ids.extend( + req_meta.group_vllm_block_ids[gid][start_blk:end_blk] + ) + + if req_meta.token_processed < req_meta.num_token_ids: + dump_tok_start = req_meta.token_processed + dump_tok_end = min( + req_meta.token_processed + new_tokens, req_meta.num_token_ids + ) + for gid, group in enumerate(groups_by_id): + start_blk = dump_tok_start // group.block_size + end_blk = 
dump_tok_end // group.block_size + if start_blk >= end_blk: + continue + dump_ucm_block_ids.extend( + req_meta.group_ucm_block_ids[gid][start_blk:end_blk] + ) + dump_vllm_block_ids.extend( + req_meta.group_vllm_block_ids[gid][start_blk:end_blk] + ) + req_meta.token_processed += new_tokens + + return RequestDispatchMeta( + (load_ucm_block_ids, load_vllm_block_ids), + (dump_ucm_block_ids, dump_vllm_block_ids), + ) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + assert self.group_manager is not None + num_groups = self.group_manager.num_groups + empty_per_group: tuple[list[int], ...] = tuple([] for _ in range(num_groups)) + + requests_dispatch_meta: dict[str, RequestDispatchMeta] = {} + + for request in scheduler_output.scheduled_new_reqs: + request_id = request.req_id + req_meta = self.requests_meta.get(request_id) + if req_meta is None: + continue + assert isinstance(req_meta, HMARequestMeta) + requests_dispatch_meta[request_id] = self._generate_hma_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.block_ids, + ) + + # Same three situations as the parent: chunked prefill (dump only), + # resumed (load + dump), decode (no-op). 
+ scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs + if not isinstance(scheduled_cached_reqs, list): + for i, request_id in enumerate(scheduled_cached_reqs.req_ids): + req_meta = self.requests_meta.get(request_id) + if req_meta is None: + continue + assert isinstance(req_meta, HMARequestMeta) + raw_new_block_ids = scheduled_cached_reqs.new_block_ids[i] + new_block_ids = ( + empty_per_group if raw_new_block_ids is None else raw_new_block_ids + ) + if hasattr(scheduled_cached_reqs, "resumed_from_preemption"): + resumed_from_preemption = ( + scheduled_cached_reqs.resumed_from_preemption[i] + ) + else: + resumed_from_preemption = ( + request_id in scheduled_cached_reqs.resumed_req_ids + ) + requests_dispatch_meta[request_id] = self._generate_hma_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + new_block_ids, + resumed_from_preemption, + ) + else: + for request in scheduled_cached_reqs: + request_id = request.req_id + req_meta = self.requests_meta.get(request_id) + if req_meta is None: + continue + assert isinstance(req_meta, HMARequestMeta) + requests_dispatch_meta[request_id] = self._generate_hma_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.new_block_ids, + request.resumed_from_preemption, + ) + + for request_id in scheduler_output.finished_req_ids: + self.requests_meta.pop(request_id, None) + + return UCMConnectorMetadata(requests_dispatch_meta) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return False, None + + +class UCMConnector(KVConnectorBase_V1, SupportsHMA): + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: "KVCacheConfig"): + super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) self.connector: KVConnectorBase_V1 ucm_config = Config(vllm_config.kv_transfer_config) self.launch_config = 
ucm_config.get_config() @@ -1189,6 +1838,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): > 1 ) + use_hma = self._vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False + if use_lite: self.connector = UCMLiteConnector(vllm_config, role) elif use_ratio_rate: @@ -1197,8 +1848,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.connector = UCMCPConnector(vllm_config, role) elif use_layerwise: self.connector = UCMLayerWiseConnector(vllm_config, role) + elif use_hma: + self.connector = UCMHMAConnector(vllm_config, role, kv_cache_config) else: - self.connector = UCMDirectConnector(vllm_config, role) + self.connector = UCMDirectConnector(vllm_config, role, kv_cache_config) def get_num_new_matched_tokens( self, @@ -1350,3 +2003,10 @@ def get_block_ids_with_load_errors(self) -> set[int]: Empty set if no load errors occurred. """ return self.connector.get_block_ids_with_load_errors() + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return self.connector.request_finished_all_groups(request, block_ids) \ No newline at end of file From d80e6016025052eb616c295c5b9240e1baf94df7 Mon Sep 17 00:00:00 2001 From: qyh111 Date: Sat, 9 May 2026 06:33:23 +0000 Subject: [PATCH 2/2] [Feat]Adapt Deepseek-V4-Flash on ascend and cuda --- ucm/integration/vllm/ucm_connector.py | 520 ++++++++++++++++++-------- 1 file changed, 362 insertions(+), 158 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index bf7268457..0cd480b8d 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -21,6 +21,14 @@ from vllm.model_executor.models.utils import extract_layer_index from vllm.platforms import current_platform from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + 
KVCacheSpec, + MambaSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from ucm.integration.vllm.device import create_device from ucm.logger import init_logger @@ -29,14 +37,6 @@ from ucm.store.factory_v1 import UcmConnectorFactoryV1 from ucm.store.ucmstore_v1 import Task, UcmKVStoreBaseV1 from ucm.utils import Config -from vllm.v1.kv_cache_interface import ( - KVCacheConfig, - KVCacheSpec, - FullAttentionSpec, - MambaSpec, - SlidingWindowSpec, - UniformTypeKVCacheSpecs, -) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -95,7 +95,11 @@ class RequestDispatchMeta: class KVCacheLayout: def __init__( - self, kvcaches, ucm_config: dict, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig" + self, + kvcaches, + ucm_config: dict, + vllm_config: "VllmConfig", + kv_cache_config: "KVCacheConfig", ) -> None: # each row is a layer, each column is a tensor_size/ptr in the layer (e.g., k, v, rope, k_index) self.base_ptrs: np.ndarray # (n_layers, n_ptrs) @@ -113,6 +117,8 @@ def __init__( name: extract_layer_index(name) for name in kvcaches.keys() } self.first_layer_id = next(iter(self.layer_name_to_id.values())) + self.num_blocks = self.kv_cache_config.num_blocks + self.layer_name_to_kv_cache_spec = layer_name_to_kv_cache_spec(kv_cache_config) self._build_layout(kvcaches) def _build_layout(self, kvcaches): @@ -227,8 +233,15 @@ class UCMDirectConnector(KVConnectorBase_V1): load -> forward -> save """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: "KVCacheConfig"): - super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self.use_layerwise = False self.kv_caches: dict[str, torch.Tensor] = {} self.local_rank = ( @@ -247,6 +260,9 @@ def __init__(self, 
vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_co ) self.head_size = vllm_config.model_config.get_head_size() self.element_size = vllm_config.model_config.dtype.itemsize + self.use_compress = hasattr( + self._vllm_config.model_config.hf_config, "compress_ratios" + ) self._kv_cache_config = kv_cache_config if current_platform.is_cuda_alike(): @@ -967,6 +983,7 @@ def request_finished_all_groups( ) -> tuple[bool, dict[str, Any] | None]: return False, None + class UCMCPConnector(UCMLayerWiseConnector): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config, role) @@ -1193,33 +1210,40 @@ def get_num_new_matched_tokens(self, request, num_computed_tokens): ) return 0, False + def layer_name_to_kv_cache_spec( kv_cache_config: KVCacheConfig, -) -> dict[str, KVCacheSpec]: +) -> dict[str, list[KVCacheSpec]]: """Map each model layer name to its concrete KVCacheSpec. Handles merged group specs and UniformTypeKVCacheSpecs (per-layer ``kv_cache_specs`` entries). 
""" - out: dict[str, KVCacheSpec] = {} + out: dict[str, list[KVCacheSpec]] = defaultdict(list) for group in kv_cache_config.kv_cache_groups: spec = group.kv_cache_spec if isinstance(spec, UniformTypeKVCacheSpecs): by_name = spec.kv_cache_specs for name in group.layer_names: - out[name] = by_name[name] + out[name].append(by_name[name]) else: for name in group.layer_names: - out[name] = spec + out[name].append(spec) return out - def block_size_from_kv_cache_spec(spec: KVCacheSpec) -> int: """Token block size used for KV scheduling / hashing for one group spec.""" + block_size = 0 if isinstance(spec, UniformTypeKVCacheSpecs): - return next(iter(spec.kv_cache_specs.values())).block_size - return spec.block_size + block_size = next(iter(spec.kv_cache_specs.values())).block_size + else: + block_size = spec.block_size + + if current_platform.device_type == "npu" and hasattr(spec, "compress_ratio"): + block_size *= spec.compress_ratio + + return block_size def sliding_window_from_kv_cache_spec(spec: KVCacheSpec) -> Optional[int]: @@ -1254,18 +1278,22 @@ def is_full_attention(self) -> bool: class KVCacheGroupManager: """Group-aware hashing and lookup for hybrid (HMA) connectors. - Splits ``kv_cache_config.kv_cache_groups`` into the full-attention group - (exactly one) and sliding-window groups, derives a per-group hash chain + Splits ``kv_cache_config.kv_cache_groups`` into full-attention groups + (one or more) and sliding-window groups, derives a per-group hash chain seed, and exposes a two-stage lookup that: - 1. Hashes ``request.all_token_ids`` with the full-attention group's block - size and runs ``store.lookup_on_prefix`` on the blocks beyond - ``hbm_hit_block_num``. + 1. For every full-attention group, hashes ``request.all_token_ids`` with + that group's block size and runs ``store.lookup_on_prefix`` on the + blocks beyond its own ``hbm_hit_block_num``. The candidate hits (in + tokens) are min'd across full-attn groups and rounded down to + ``lcm_block_size``. 2. 
For each sliding-window group, re-hashes the same prefix with that group's own block size and verifies the last - ``sliding_window // block_size`` blocks all exist via ``store.lookup``. - If any sliding-window group fails this check, the whole external hit - is downgraded to zero. + ``max(1, sliding_window // block_size)`` blocks all exist via + ``store.lookup`` (when ``block_size > sliding_window`` — e.g. on + Ascend — the last single block already covers the SW). If any + sliding-window group fails this check, the whole external hit is + downgraded to zero. """ def __init__( @@ -1278,7 +1306,8 @@ def __init__( # Indexed by original group_id; positions match # ``kv_cache_config.kv_cache_groups``. self.groups_by_id: list[GroupInfo] = [] - full_attn_groups: list[GroupInfo] = [] + # All groups whose spec has no sliding_window. Order follows group_id. + self.full_attn_groups: list[GroupInfo] = [] self.sliding_window_groups: list[GroupInfo] = [] for group_id, group in enumerate(kv_cache_config.kv_cache_groups): @@ -1297,28 +1326,44 @@ def __init__( ) self.groups_by_id.append(info) if info.is_full_attention: - full_attn_groups.append(info) + self.full_attn_groups.append(info) else: self.sliding_window_groups.append(info) - assert len(full_attn_groups) == 1, ( - f"UCMHMAConnector expects exactly one full-attention group, got " - f"{len(full_attn_groups)}: " - f"{[(g.group_id, g.block_size) for g in full_attn_groups]}" + assert len(self.full_attn_groups) >= 1, ( + "UCMHMAConnector expects at least one full-attention group in " + "kv_cache_config.kv_cache_groups." ) - self.full_attn_group: GroupInfo = full_attn_groups[0] - # Sliding-window groups whose block_size does not divide the - # full-attn group's block_size cannot align tail checks with full-attn - # block boundaries; reject early to surface configuration mistakes. 
- full_attn_block_size = self.full_attn_group.block_size + # Resume points must be aligned to the LCM of every group's + # block_size so that per-group block accounting (including each + # full-attn group's lookup result and every SW group's tail slice) + # lands on a clean block boundary. + all_block_sizes = [g.block_size for g in self.groups_by_id] + self.lcm_block_size: int = math.lcm(*all_block_sizes) + + for g in self.groups_by_id: + assert self.lcm_block_size % g.block_size == 0, ( + f"group {g.group_id} block_size={g.block_size} does not " + f"divide LCM={self.lcm_block_size}" + ) for sw in self.sliding_window_groups: - if full_attn_block_size % sw.block_size != 0: - raise ValueError( - f"Sliding window group {sw.group_id} block_size=" - f"{sw.block_size} does not divide full-attn block_size=" - f"{full_attn_block_size}." - ) + # The dump path stores only ``[B - sliding_window, B)`` for each + # LCM boundary B. Requiring ``sliding_window <= lcm_block_size`` + # guarantees consecutive boundaries' tails do not overlap, so the + # incremental dump can append each boundary's tail without + # cross-boundary deduplication. + assert sw.sliding_window <= self.lcm_block_size, ( + f"sliding window group {sw.group_id} sliding_window=" + f"{sw.sliding_window} > lcm_block_size=" + f"{self.lcm_block_size}; not supported." + ) + # On some backends (e.g. Ascend) ``block_size`` can exceed + # ``sliding_window``; in that case a single block already holds + # more than a window of tokens, so we will treat the last block + # as the SW tail and skip the divisibility check. 
+ if sw.block_size >= sw.sliding_window: + continue if sw.sliding_window % sw.block_size != 0: raise ValueError( f"Sliding window group {sw.group_id} sliding_window=" @@ -1328,8 +1373,9 @@ def __init__( logger.info( "KVCacheGroupManager initialized: " - f"full_attn_group=({self.full_attn_group.group_id}, " - f"{self.full_attn_group.block_size}), " + f"lcm_block_size={self.lcm_block_size}, " + f"full_attn_groups=" + f"{[(g.group_id, g.block_size) for g in self.full_attn_groups]}, " f"sliding_window_groups=" f"{[(g.group_id, g.block_size, g.sliding_window) for g in self.sliding_window_groups]}" ) @@ -1355,22 +1401,18 @@ def compute_block_hashes( ret.append(hash_value) return ret - def compute_all_group_block_ids( - self, token_ids: list[int] - ) -> list[list[bytes]]: + def compute_all_group_block_ids(self, token_ids: list[int]) -> list[list[bytes]]: """Compute full block hashes for every group, indexed by group_id. ``ret[gid]`` covers all aligned blocks of ``token_ids`` using group ``gid``'s ``block_size`` and chain seed. The trailing partial block (if any) is dropped, matching :meth:`compute_block_hashes`. """ - return [ - self.compute_block_hashes(g, token_ids) for g in self.groups_by_id - ] + return [self.compute_block_hashes(g, token_ids) for g in self.groups_by_id] def lookup_external_hit_tokens( self, - hbm_hit_block_num: int, + num_computed_tokens: int, store: "UcmKVStoreBaseV1", group_block_ids: list[list[bytes]], ) -> tuple[int, int]: @@ -1379,63 +1421,89 @@ def lookup_external_hit_tokens( ``group_block_ids`` must have one entry per group, indexed by the original ``group_id`` (see :meth:`compute_all_group_block_ids`). + Stage 1 — every full-attention group runs ``lookup_on_prefix`` + beyond its own ``hbm_hit_block_num``; the candidate hits are taken + as a min and rounded down to ``lcm_block_size`` so the final + external hit is consistent across all full-attn groups and aligns + to the kv-cache page granularity expected by the scheduler. 
+ + Stage 2 — every sliding-window group must have the last + ``max(1, sliding_window // block_size)`` blocks before ``total_hit_tokens`` + present in the store; if any group fails, the whole external hit + is downgraded to zero. + Returns: Tuple of - - ``external_hit_tokens``: tokens hit beyond ``hbm_hit_block_num``, - aligned to the full-attn group's block size. ``0`` if any - sliding-window group fails its tail check. - - ``external_hit_blocks``: number of full-attn blocks hit beyond - ``hbm_hit_block_num`` (also ``0`` on downgrade). + - ``external_hit_tokens``: tokens hit beyond ``num_computed_tokens``, + aligned to ``lcm_block_size``. ``0`` if any check fails. + - ``external_hit_lcm_blocks``: ``external_hit_tokens // + lcm_block_size`` (also ``0`` on downgrade). """ assert len(group_block_ids) == self.num_groups, ( f"group_block_ids length {len(group_block_ids)} does not match " f"num_groups {self.num_groups}" ) + assert num_computed_tokens % self.lcm_block_size == 0, ( + f"num_computed_tokens={num_computed_tokens} is not aligned to " + f"lcm_block_size={self.lcm_block_size}" + ) - full_attn = self.full_attn_group - full_attn_block_ids = group_block_ids[full_attn.group_id] - - external_block_ids = full_attn_block_ids[hbm_hit_block_num:] - if not external_block_ids: - return 0, 0 - - try: - external_hit_blocks = store.lookup_on_prefix(external_block_ids) + 1 - except Exception as e: - logger.error( - f"full-attn group {full_attn.group_id} lookup error. " - f"{type(e).__name__}: {e}" - ) - return 0, 0 - - if external_hit_blocks <= 0: + # Stage 1: each full-attn group contributes a candidate hit count. 
+ candidates: list[int] = [] + for fa in self.full_attn_groups: + fa_block_ids = group_block_ids[fa.group_id] + fa_hbm_blocks = num_computed_tokens // fa.block_size + fa_external = fa_block_ids[fa_hbm_blocks:] + if not fa_external: + candidates.append(0) + continue + try: + fa_hit_blocks = store.lookup_on_prefix(fa_external) + 1 + except Exception as e: + logger.error( + f"full-attn group {fa.group_id} lookup error. " + f"{type(e).__name__}: {e}" + ) + candidates.append(0) + continue + candidates.append(max(fa_hit_blocks, 0) * fa.block_size) + + # Resume boundary must be a multiple of lcm_block_size so every + # group's tail/dispatch slicing lands on a real block boundary. + min_external_hit_tokens = min(candidates) + external_hit_tokens = ( + min_external_hit_tokens // self.lcm_block_size + ) * self.lcm_block_size + if external_hit_tokens <= 0: return 0, 0 - external_hit_tokens = external_hit_blocks * full_attn.block_size - # Resume boundary: SW layers need ``[total_hit_tokens - sliding_window, - # total_hit_tokens)`` to be present in external storage, regardless of - # whether some of those tokens are also covered by ``hbm_hit_tokens`` - # (the SW manager has dropped them anyway when resuming). - total_hit_tokens = ( - hbm_hit_block_num * full_attn.block_size + external_hit_tokens - ) - + # Stage 2: every SW group's tail window must be in the store. + total_hit_tokens = num_computed_tokens + external_hit_tokens for sw in self.sliding_window_groups: - # Q3: not enough hit to fill a full sliding window — treat as miss. - if total_hit_tokens < sw.sliding_window: + # When ``block_size > sliding_window`` (e.g. on Ascend) a single + # block already covers more than a window of tokens, so we use + # the last block as the SW tail (tail_count = 1). Otherwise + # ``sliding_window`` is a multiple of ``block_size`` (validated + # in __init__) and we take exactly ``sliding_window/block_size`` + # blocks. 
+ tail_count = max(1, sw.sliding_window // sw.block_size) + min_required_tokens = tail_count * sw.block_size + if total_hit_tokens < min_required_tokens: logger.info( f"sliding window group {sw.group_id} tail check skipped: " f"total_hit_tokens={total_hit_tokens} < " - f"sliding_window={sw.sliding_window}, downgrade to 0." + f"min_required={min_required_tokens} " + f"(sliding_window={sw.sliding_window}, " + f"block_size={sw.block_size}), downgrade to 0." ) return 0, 0 sw_block_ids = group_block_ids[sw.group_id] - # ``sw.block_size`` divides ``total_hit_tokens`` (block_size - # divides full_attn.block_size, validated in __init__), so this - # slice is exact and never crosses the group's last full block. + # ``sw.block_size`` divides ``total_hit_tokens`` because + # ``total_hit_tokens`` is a multiple of ``lcm_block_size`` and + # ``lcm_block_size`` is divisible by every group's block_size + # (validated in __init__). num_blocks_in_total_hit = total_hit_tokens // sw.block_size - tail_count = sw.sliding_window // sw.block_size tail_block_ids = sw_block_ids[ num_blocks_in_total_hit - tail_count : num_blocks_in_total_hit ] @@ -1454,7 +1522,7 @@ def lookup_external_hit_tokens( ) return 0, 0 - return external_hit_tokens, external_hit_blocks + return external_hit_tokens, external_hit_tokens // self.lcm_block_size class HMAKVCacheLayout(KVCacheLayout): @@ -1465,9 +1533,6 @@ def __init__( vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig", ): - self.layer_name_to_kv_cache_spec = layer_name_to_kv_cache_spec( - kv_cache_config - ) super().__init__(kvcaches, ucm_config, vllm_config, kv_cache_config) def _build_layout(self, kvcaches): @@ -1482,9 +1547,11 @@ def _build_layout(self, kvcaches): sample_layer_name = raw_tensor.shared_by[0] kv_layer = kvcaches.get(sample_layer_name) if kv_layer is None: - logger.warning(f"kv_layer {sample_layer_name} not found in kvcaches") + logger.warning( + f"kv_layer {sample_layer_name} not found in kvcaches" + ) continue - kv_cache_spec = 
self.layer_name_to_kv_cache_spec[sample_layer_name] + kv_cache_spec = self.layer_name_to_kv_cache_spec[sample_layer_name][0] if isinstance(kv_layer, torch.Tensor): ptrs.append(kv_layer.data_ptr()) tensor_sizes.append(kv_cache_spec.page_size_bytes) @@ -1508,11 +1575,82 @@ def _build_layout(self, kvcaches): ) +class AscendDSV4Layout(HMAKVCacheLayout): + def __init__( + self, + kvcaches, + ucm_config: dict, + vllm_config: "VllmConfig", + kv_cache_config: "KVCacheConfig", + ): + super().__init__(kvcaches, ucm_config, vllm_config, kv_cache_config) + self.indexer_scale_size_bytes = 0 + for _, layer_specs in self.layer_name_to_kv_cache_spec.items(): + for spec in layer_specs: + if hasattr(spec, "indexer_scale_size_bytes"): + self.indexer_scale_size_bytes = spec.indexer_scale_size_bytes + break + + def _build_layout(self, kvcaches): + self.indexer_scale_size_bytes = 0 + for _, layer_specs in self.layer_name_to_kv_cache_spec.items(): + for spec in layer_specs: + if hasattr(spec, "indexer_scale_size_bytes"): + self.indexer_scale_size_bytes = spec.indexer_scale_size_bytes + break + + base_ptrs = [] + tensor_size_lists = [] + + for raw_tensor in self.kv_cache_config.kv_cache_tensors: + ptrs = [] + tensor_sizes = [] + kv_size = raw_tensor.size - self.indexer_scale_size_bytes * self.num_blocks + + if raw_tensor.shared_by: + sample_layer_name = raw_tensor.shared_by[0] + kv_layer = kvcaches.get(sample_layer_name) + if kv_layer is None: + logger.warning( + f"kv_layer {sample_layer_name} not found in kvcaches" + ) + continue + kv_cache_specs = self.layer_name_to_kv_cache_spec[sample_layer_name] + if isinstance(kv_layer, (tuple, list)): + ptrs.append(kv_layer[0].data_ptr()) + tensor_sizes.append( + kv_cache_specs[0].page_size_bytes + - self.indexer_scale_size_bytes + ) + ptrs.append(kv_layer[0].data_ptr() + kv_size) + tensor_sizes.append(self.indexer_scale_size_bytes) + else: + logger.warning(f"unsupported kv_layer type: {type(kv_layer)}") + + if not ptrs and not tensor_sizes: + 
continue + + base_ptrs.append(ptrs) + tensor_size_lists.append(tensor_sizes) + + self.base_ptrs = np.asarray(base_ptrs, dtype=np.uint64) + self.tensor_size_lists = np.asarray(tensor_size_lists, dtype=np.uint64) + + logger.info( + f"base_ptrs: {self.base_ptrs.shape}, tensor_size_lists: {self.tensor_size_lists.shape}" + ) + + class UCMHMAConnector(UCMDirectConnector, SupportsHMA): def __init__( - self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: "KVCacheConfig" + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", ): - super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) # group manager only lives on the scheduler side, where ``self._seed`` # and ``self.request_hasher`` are populated by the parent ctor. self.group_manager: Optional[KVCacheGroupManager] = None @@ -1522,23 +1660,35 @@ def __init__( request_hasher=self.request_hasher, base_seed=self._seed, ) - full_attn_block_size = self.group_manager.full_attn_group.block_size + lcm_block_size = self.group_manager.lcm_block_size # Override the inherited ``block_size`` (which comes from # ``cache_config.block_size``) so prefix accounting in this class - # is consistent with the full-attn group. - self.block_size = full_attn_block_size - self.hash_block_size = full_attn_block_size - - logger.info(f"UCMHMAConnector initialized with use_layerwise={self.use_layerwise}") + # is consistent with every group's block boundaries — vLLM's + # hybrid scheduler aligns ``num_computed_tokens`` to the LCM of + # all groups' block_size, and so do we. 
+ self.block_size = lcm_block_size + self.hash_block_size = lcm_block_size + + logger.info( + f"UCMHMAConnector initialized with use_layerwise={self.use_layerwise}" + ) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches = kv_caches - self.kv_cache_layout = HMAKVCacheLayout( - self.kv_caches, - self.launch_config, - self._vllm_config, - self._kv_cache_config, - ) + if current_platform.device_type == "npu" and self.use_compress: + self.kv_cache_layout = AscendDSV4Layout( + self.kv_caches, + self.launch_config, + self._vllm_config, + self._kv_cache_config, + ) + else: + self.kv_cache_layout = HMAKVCacheLayout( + self.kv_caches, + self.launch_config, + self._vllm_config, + self._kv_cache_config, + ) self.store = self._create_store(self.kv_cache_layout) self.block_data_size = self.kv_cache_layout.block_size self.device = create_device() @@ -1551,13 +1701,15 @@ def get_num_new_matched_tokens( "connector, where the group manager is initialized." ) - full_attn = self.group_manager.full_attn_group - full_attn_block_size = full_attn.block_size - assert num_computed_tokens % full_attn_block_size == 0, ( + lcm_block_size = self.group_manager.lcm_block_size + assert num_computed_tokens % lcm_block_size == 0, ( f"num_computed_tokens={num_computed_tokens} is not aligned to " - f"full-attn group block_size={full_attn_block_size}" + f"lcm_block_size={lcm_block_size}" ) - hbm_hit_block_num = num_computed_tokens // full_attn_block_size + # ``hbm_hit_block_num`` and ``total_hit_block_num`` are tracked in + # LCM-block units in HMA mode; per-group block ids/counts are derived + # from these via each group's own block_size when needed. + hbm_hit_block_num = num_computed_tokens // lcm_block_size # Skip persistence if token count is below the threshold. 
if self.persist_token_threshold > request.num_tokens: @@ -1572,20 +1724,23 @@ def get_num_new_matched_tokens( group_ucm_block_ids = self.group_manager.compute_all_group_block_ids( request.all_token_ids ) - full_attn_block_ids = group_ucm_block_ids[full_attn.group_id] + # Legacy ``ucm_block_ids`` mirrors the first full-attn group (by + # group_id order) for callers that still consume the flat list. + primary_full_attn = self.group_manager.full_attn_groups[0] + primary_block_ids = group_ucm_block_ids[primary_full_attn.group_id] - external_hit_tokens, external_hit_blocks = ( + external_hit_tokens, external_hit_lcm_blocks = ( self.group_manager.lookup_external_hit_tokens( - hbm_hit_block_num, self.store, group_ucm_block_ids + num_computed_tokens, self.store, group_ucm_block_ids ) ) if ( self.enable_record_traces and request.request_id not in self.requests_meta - and len(full_attn_block_ids) > 0 + and len(primary_block_ids) > 0 ): - hex_block_ids = [b.hex() for b in full_attn_block_ids] + hex_block_ids = [b.hex() for b in primary_block_ids] logger.info_once( f"timestamp: {time.perf_counter()}, " f"input_length: {request.num_tokens}, " @@ -1593,31 +1748,33 @@ def get_num_new_matched_tokens( f"ucm_block_ids: {hex_block_ids}" ) - total_hit_block_num = hbm_hit_block_num + external_hit_blocks + total_hit_block_num = hbm_hit_block_num + external_hit_lcm_blocks logger.info_once( f"request_id: {request.request_id}, " - f"total_blocks_num: {len(full_attn_block_ids)}, " + f"total_lcm_blocks: {request.num_tokens // lcm_block_size}, " f"hit hbm: {hbm_hit_block_num}, " - f"hit external: {external_hit_blocks}" + f"hit external: {external_hit_lcm_blocks}, " + f"total_tokens: {len(request.all_token_ids)}" ) - if self.metrics_config and len(full_attn_block_ids) > 0: + if self.metrics_config and len(primary_block_ids) > 0: ucmmetrics.update_stats( { - "interval_lookup_hit_rates": external_hit_blocks - / len(full_attn_block_ids) + "interval_lookup_hit_rates": external_hit_lcm_blocks + * 
lcm_block_size + / (len(primary_block_ids) * primary_full_attn.block_size) }, ) # When all the tokens are cached in ssd or hbm, we need to recompute # the last token. This branch will be removed once vLLM scheduler # provides a better solution in the future. - num_total_hit_tokens = total_hit_block_num * full_attn_block_size + num_total_hit_tokens = total_hit_block_num * lcm_block_size if num_total_hit_tokens == request.num_tokens and external_hit_tokens > 0: external_hit_tokens -= 1 self.requests_meta[request.request_id] = HMARequestMeta( - ucm_block_ids=full_attn_block_ids, + ucm_block_ids=primary_block_ids, hbm_hit_block_num=hbm_hit_block_num, total_hit_block_num=total_hit_block_num, num_token_ids=len(request.all_token_ids), @@ -1647,19 +1804,26 @@ def _generate_hma_dispatch_meta( Layout per group within ``[token_processed, token_processed + new_tokens)``: - **load** (only when ``external_hit_blocks > 0`` and ``need_load``): - full-attn group: tokens ``[hbm_hit_tokens, total_hit_tokens)`` - - sliding-window group: tokens - ``[total_hit_tokens - sliding_window, total_hit_tokens)`` - (the SW window is reloaded every resume because older blocks are - evicted by the SW manager). - - **dump**: every newly-completed full block of every group inside - ``[token_processed, token_processed + new_tokens)`` (Option A: SW - groups dump *all* blocks, not only the tail, so future requests - can hit at any full-attn boundary). + - sliding-window group: the blocks covering tokens + ``[total_hit_tokens - sliding_window, total_hit_tokens)``; + ``start_blk`` is floor-rounded so when ``block_size > + sliding_window`` (e.g. on Ascend) we naturally load the single + last block. The SW window is reloaded every resume because + older blocks are evicted by the SW manager. + - **dump** of ``[token_processed, token_processed + new_tokens)``: + - full-attn group: every newly-completed full block (the + ``lookup_on_prefix`` chain needs every prefix block to be + present). 
+ - sliding-window group: only the last + ``max(1, sliding_window/block_size)`` blocks before each LCM + boundary reached in this range. Lookup always resumes at LCM + boundaries and stage-2 SW check only inspects those tails, so + blocks between tails would be dead weight in the store. """ assert self.group_manager is not None groups_by_id = self.group_manager.groups_by_id num_groups = self.group_manager.num_groups - full_attn_bs = self.group_manager.full_attn_group.block_size + lcm_block_size = self.group_manager.lcm_block_size assert len(new_vllm_block_ids_per_group) == num_groups, ( f"new_vllm_block_ids_per_group length " @@ -1667,22 +1831,20 @@ def _generate_hma_dispatch_meta( f"num_groups {num_groups}" ) for gid in range(num_groups): - req_meta.group_vllm_block_ids[gid].extend( - new_vllm_block_ids_per_group[gid] - ) + req_meta.group_vllm_block_ids[gid].extend(new_vllm_block_ids_per_group[gid]) load_ucm_block_ids: list[bytes] = [] load_vllm_block_ids: list[int] = [] dump_ucm_block_ids: list[bytes] = [] dump_vllm_block_ids: list[int] = [] - external_hit_blocks = ( + external_hit_lcm_blocks = ( req_meta.total_hit_block_num - req_meta.hbm_hit_block_num ) - hbm_hit_tokens = req_meta.hbm_hit_block_num * full_attn_bs - total_hit_tokens = req_meta.total_hit_block_num * full_attn_bs + hbm_hit_tokens = req_meta.hbm_hit_block_num * lcm_block_size + total_hit_tokens = req_meta.total_hit_block_num * lcm_block_size - if need_load and external_hit_blocks > 0: + if need_load and external_hit_lcm_blocks > 0: for gid, group in enumerate(groups_by_id): if group.is_full_attention: load_tok_start = hbm_hit_tokens @@ -1705,17 +1867,49 @@ def _generate_hma_dispatch_meta( dump_tok_end = min( req_meta.token_processed + new_tokens, req_meta.num_token_ids ) + # LCM boundaries B with ``dump_tok_start < B <= dump_tok_end``. + # SW groups only need the tail at these boundaries because lookup + # always resumes at LCM boundaries (see + # ``lookup_external_hit_tokens`` stage 2). 
+ first_lcm_b = (dump_tok_start // lcm_block_size + 1) * lcm_block_size + last_lcm_b = (dump_tok_end // lcm_block_size) * lcm_block_size + for gid, group in enumerate(groups_by_id): - start_blk = dump_tok_start // group.block_size - end_blk = dump_tok_end // group.block_size - if start_blk >= end_blk: - continue - dump_ucm_block_ids.extend( - req_meta.group_ucm_block_ids[gid][start_blk:end_blk] - ) - dump_vllm_block_ids.extend( - req_meta.group_vllm_block_ids[gid][start_blk:end_blk] - ) + if group.is_full_attention: + # Dump every newly completed block: ``lookup_on_prefix`` + # walks the full prefix chain so any gap would truncate + # future hits. + start_blk = dump_tok_start // group.block_size + end_blk = dump_tok_end // group.block_size + if start_blk >= end_blk: + continue + dump_ucm_block_ids.extend( + req_meta.group_ucm_block_ids[gid][start_blk:end_blk] + ) + dump_vllm_block_ids.extend( + req_meta.group_vllm_block_ids[gid][start_blk:end_blk] + ) + else: + # Dump only the tail blocks at each LCM boundary reached + # in this range. Since ``sliding_window <= + # lcm_block_size`` (validated in ``KVCacheGroupManager``), + # consecutive boundaries' tails do not overlap and we can + # extend the lists without dedup. 
+ if first_lcm_b > last_lcm_b: + continue + tail_count = max(1, group.sliding_window // group.block_size) + b = first_lcm_b + while b <= last_lcm_b: + end_blk = b // group.block_size + start_blk = max(0, end_blk - tail_count) + if start_blk < end_blk: + dump_ucm_block_ids.extend( + req_meta.group_ucm_block_ids[gid][start_blk:end_blk] + ) + dump_vllm_block_ids.extend( + req_meta.group_vllm_block_ids[gid][start_blk:end_blk] + ) + b += lcm_block_size req_meta.token_processed += new_tokens return RequestDispatchMeta( @@ -1799,8 +1993,15 @@ def request_finished_all_groups( class UCMConnector(KVConnectorBase_V1, SupportsHMA): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: "KVCacheConfig"): - super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self.connector: KVConnectorBase_V1 ucm_config = Config(vllm_config.kv_transfer_config) self.launch_config = ucm_config.get_config() @@ -1838,7 +2039,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_co > 1 ) - use_hma = self._vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False + use_hma = ( + self._vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False + or os.getenv("USE_MULTI_GROUPS_KV_CACHE") == "1" + ) if use_lite: self.connector = UCMLiteConnector(vllm_config, role) @@ -2009,4 +2213,4 @@ def request_finished_all_groups( request: "Request", block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: - return self.connector.request_finished_all_groups(request, block_ids) \ No newline at end of file + return self.connector.request_finished_all_groups(request, block_ids)