diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 269ccb0be..359bdbed8 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -56,6 +56,7 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) +# TODO : This function will be deprecated in future. @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: # Find dims @@ -75,6 +76,7 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates return ops.ScatterND(data, indices, updates) +# TODO : This function will be deprecated in future. class CtxScatterFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): @@ -92,6 +94,7 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) +# TODO : This function will be deprecated in future. @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) @@ -99,6 +102,7 @@ def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxsc return ops.GatherND(data, ctx_indices, batch_dims=1) +# TODO : This function will be deprecated in future. class CtxGatherFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, ctx_indices: torch.Tensor): diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index cc9693716..bb97feb93 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -56,6 +56,7 @@ def symbolic( return g.onnxscript_op(CtxScatterCB, data, batch_index, position_ids, updates).setTypeAs(data) +# TODO : This function will be deprecated in future. @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxScatterCB3D( data: onnxscript.FLOAT, batch_index: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT @@ -76,6 +77,7 @@ def CtxScatterCB3D( return ops.ScatterND(data, indices, updates) +# TODO : This function will be deprecated in future. class CtxScatterFuncCB3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, batch_index: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): @@ -139,6 +141,7 @@ def symbolic( return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data) +# TODO : This function will be deprecated in future. @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherCB3D( data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 @@ -158,6 +161,7 @@ def CtxGatherCB3D( return ops.GatherND(data, indices) +# TODO : This function will be deprecated in future. class CtxGatherFuncCB3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor): diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 5452589f6..aad7050c6 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -1,666 +1,651 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - - -from collections.abc import Iterable -from typing import Any, Dict, List, Optional, Tuple - -import torch -from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache - -from QEfficient.customop import ( - CtxGatherFunc, - CtxGatherFunc3D, - CtxGatherFuncCB, - CtxGatherFuncCB3D, - CtxScatterFunc, - CtxScatterFunc3D, - CtxScatterFuncCB, - CtxScatterFuncCB3D, -) - - -class QEffDynamicLayer(DynamicLayer): - def read_only(self, cache_kwargs): - """ - Reads the `key_states` and `value_states` for the layer. - - Parameters: - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Gather - k_out, v_out = self.keys, self.values - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) - ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) - - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - - if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) - else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) - - v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - return k_out, v_out - - def write_only(self, key_states, value_states, cache_kwargs): - """ - Write in the cache with the new `key_states` and `value_states` for the layer. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - """ - # Update the cache - if self.keys is None: - self.keys = key_states - self.values = value_states - else: - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs - - # Scatter - if batch_index is not None: - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - - self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) - self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) - else: - self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) - self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.keys is None: - self.keys = key_states - self.values = value_states - k_out, v_out = self.keys, self.values - else: - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs - - # Scatter - if batch_index is not None: - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - - self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) - - self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) - else: - self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) - self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) - - k_out, v_out = self.keys, self.values - - # Gather - ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) - else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) - v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - - return k_out, v_out - - # TODO:This function will be depercated in future. - def update3D( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if self.keys is None: - self.keys = key_states - self.values = value_states - k_out, v_out = self.keys, self.values - else: - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) - - # Scatter - if batch_index is not None: - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - - self.keys = CtxScatterFuncCB3D.apply(self.keys, batch_index, scatter_position_ids, key_states) - - self.values = CtxScatterFuncCB3D.apply(self.values, batch_index, scatter_position_ids, value_states) - else: - self.keys = CtxScatterFunc3D.apply(self.keys, position_ids, key_states) - self.values = CtxScatterFunc3D.apply(self.values, position_ids, value_states) - - k_out, v_out = self.keys, self.values - - # Gather - ctx_len = k_out.shape[1] - ctx_indices = torch.arange(ctx_len)[None, ...] - gather_limit = position_ids.max(1, keepdim=True).values - invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - if batch_index is not None: - k_out = CtxGatherFuncCB3D.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB3D.apply(v_out, batch_index, ctx_indices) - else: - k_out = CtxGatherFunc3D.apply(k_out, ctx_indices) - v_out = CtxGatherFunc3D.apply(v_out, ctx_indices) - - v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - - return k_out, v_out - - -class QEffDynamicCache(DynamicCache): - """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - - Optimized implementation for the Cloud AI 100 to reuse KV Cache. - - get the position_ids input using kwargs. - - Use custom Onnxscript ops to write optimized version to generate Onnx model. - - """ - - def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): - # Remove layer_classes if present to avoid duplicate argument - kwargs.pop("layer_classes", None) - from transformers.cache_utils import Cache # Import here to avoid circular import - - Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs) - if ddp_cache_data is not None: - for key_states, value_states in ddp_cache_data: - self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states)) - - def read_only(self, layer_idx, cache_kwargs): - """ - Reads the `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - return self.layers[layer_idx].read_only(cache_kwargs) - - def write_only(self, key_states, value_states, layer_idx, cache_kwargs): - """ - Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - """ - self.append_new_layers(layer_idx) - return self.layers[layer_idx].write_only(key_states, value_states, cache_kwargs) - - # TODO:This function will be depercated in future. - def update3D( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - self.append_new_layers(layer_idx) - return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs) - - -class QEffEncoderDecoderCache(EncoderDecoderCache): - """ - Updated the `EncoderDecoderCache` to use the `QEffDynamicCache` for both self-attention and cross-attention caches. - """ - - @classmethod - def from_legacy_cache( - cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "EncoderDecoderCache": - """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" - cache = cls( - self_attention_cache=QEffDynamicCache(), - cross_attention_cache=QEffDynamicCache(), - ) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx][:2] - cache.self_attention_cache.update(key_states, value_states, layer_idx) - if len(past_key_values[layer_idx]) > 2: - key_states, value_states = past_key_values[layer_idx][2:] - cache.cross_attention_cache.update(key_states, value_states, layer_idx) - cache.is_updated[layer_idx] = True - return cache - - -# TODO:This function will be depercated in future. -class QEffHybridCache(HybridCache): - def __init__(self, config, batch_size, max_cache_len): - super().__init__(config, batch_size, max_cache_len=max_cache_len) - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - - @classmethod - def from_legacy_cache( - cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "HybridCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls(config, batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or len(self.key_cache[layer_idx]) == 0 # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - else: - position_ids = cache_kwargs.get("position_ids") - sliding_window_pattern = cache_kwargs.get("sliding_window_pattern") - is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern)) - layer_ctx_len = self.key_cache[layer_idx].shape[2] - kv_position_ids = torch.where( - (~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1) - ) - - kv_position_ids = torch.where( - is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2), - (position_ids + 1) % layer_ctx_len, - kv_position_ids, - ) - - valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) - key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) - value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Original Gather - ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - - all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 - rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) - rolling_indices = rolling_indices[:ctx_len] - final_indices = torch.where( - (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices - ) - k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) - ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) - return k_out, v_out - - -# TODO:This function will be depercated in future. -class QEffHybridChunkedCache(HybridChunkedCache): - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or len(self.key_cache[layer_idx]) == 0 # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `HybridChunkedCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "HybridChunkedCache": - """Converts a cache in the legacy cache format into an equivalent `HybridChunkedCache`. Used for - backward compatibility.""" - cache = cls(config, max_batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - - else: - position_ids = cache_kwargs.get("position_ids") - is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) - - # Update the position_ids to handle the sliding window - layer_ctx_len = self.key_cache[layer_idx].shape[2] - kv_position_ids = torch.where( - (~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1) - ) - - kv_position_ids = torch.where( - is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2), - (position_ids + 1) % layer_ctx_len, - kv_position_ids, - ) - - valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) - key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) - value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Original Gather - ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) - ctx_len = min(layer_ctx_len, ctx_len) - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - - # Rolling indices for sliding window - all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 - rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) - rolling_indices = rolling_indices[:ctx_len] - final_indices = torch.where( - (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices - ) - k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) - ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) - return k_out, v_out - - -# This is a hack for now, until we get to merging this code with HybridCache class, -# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and -# ours are made to work with AIC -class QEffHybridCacheForGPTOSS: - def __init__(self, config, batch_size, max_cache_len, sliding_window_len): - self.max_cache_len = max_cache_len - self.batch_size = batch_size - self.sliding_window_len = sliding_window_len - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - - @classmethod - def from_legacy_cache( - cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "HybridCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls( - config, - batch_size=past_key_values[0][0].shape[0], - max_cache_len=past_key_values[1][0].shape[2], - sliding_window_len=past_key_values[0][0].shape[2], - ) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or len(self.key_cache[layer_idx]) == 0 # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - else: - position_ids = cache_kwargs.get("position_ids") - is_sliding_layer = cache_kwargs.get("is_sliding") - sliding_window = cache_kwargs.get("sliding_window") - batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs - - if is_sliding_layer: - kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) - else: - kv_position_ids = position_ids - - if batch_index is not None: - if torch.onnx.is_in_onnx_export(): - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids) - else: - scatter_position_ids = kv_position_ids - self.key_cache[layer_idx] = CtxScatterFuncCB.apply( - self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states - ) - self.value_cache[layer_idx] = CtxScatterFuncCB.apply( - self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states - ) - else: - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) - - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Original Gather - if is_sliding_layer: - ctx_len = self.key_cache[layer_idx].shape[2] - else: - ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) - - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - - if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) - else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) - - v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - return k_out, v_out +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from abc import ABC +from collections.abc import Iterable +from typing import Any, Dict, Optional, Tuple + +import torch + +from QEfficient.customop import ( + CtxGatherFunc, + CtxGatherFuncCB, + CtxScatterFunc, + CtxScatterFuncCB, +) +from QEfficient.utils.constants import INVALID_IDX + + +class QEffDynamicLayer(ABC): + def __init__(self): + self.keys, self.values = None, None + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + if self.keys is None or self.keys.numel() == 0: + return 0 + return self.keys.shape[-2] + + @classmethod + def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "QEffDynamicLayer": + """ + Build a `QEffDynamicLayer` instance from pre-existing key/value tensors. + + Args: + keys (`torch.Tensor`): + Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + values (`torch.Tensor`): + Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``. + + Returns: + `QEffDynamicLayer`: The newly constructed layer whose internal cache directly references + the supplied tensors. + """ + layer = cls() + layer.keys = keys + layer.values = values + return layer + + def read_only(self, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer. + + Parameters: + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Gather + k_out, v_out = self.keys, self.values + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = INVALID_IDX + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + + def write_only(self, key_states, value_states, cache_kwargs): + """ + Write in the cache with the new `key_states` and `value_states` for the layer. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + """ + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, INVALID_IDX, position_ids) + else: + scatter_position_ids = position_ids + + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + else: + self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + k_out, v_out = self.keys, self.values + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, INVALID_IDX, position_ids) + else: + scatter_position_ids = position_ids + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + else: + self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + + k_out, v_out = self.keys, self.values + + # Gather + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = INVALID_IDX + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + + +class QEffDynamicCache: + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + - Optimized implementation for the Cloud AI 100 to reuse KV Cache. + - get the position_ids input using kwargs. + - Use custom Onnxscript ops to write optimized version to generate Onnx model. + + """ + + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + self.layers = [] + self.layer_classes = QEffDynamicLayer + if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states)) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ + if getattr(self, "layers", None) is None: + if getattr(self, "key_cache", None) is not None: + return len(self.key_cache) + return 0 + # Empty dynamic caches initialize an empty layer to be ready for first update + dynamic_empty = ( + getattr(self, "layers", None) is not None + and len(self.layers) == 1 + and isinstance(self.layers[0], QEffDynamicLayer) + and self.layers[0].keys is None + ) + return len(self.layers) if not dynamic_empty else 0 + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: + """ + Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility. + """ + return tuple((layer.keys, layer.values) for layer in self.layers) + + @classmethod + def from_legacy_cache( + cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...] + ) -> "QEffDynamicCache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def append_new_layers(self, layer_idx: int) -> None: + """ + Appends layers to the cache until the layer `layer_idx` is reached. + Used for preallocation in static caches and on the fly in dynamic caches. + + Args: + layer_idx (`int`): + The index of the layer to append. + """ + while len(self.layers) <= layer_idx: + new_layer_class = ( + self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes + ) + new_layer = new_layer_class() + self.layers.append(new_layer) + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position""" + if layer_idx >= len(self.layers): + return 0 + return self.layers[layer_idx].get_seq_length() + + def read_only(self, layer_idx, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + return self.layers[layer_idx].read_only(cache_kwargs) + + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): + """ + Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].write_only(key_states, value_states, cache_kwargs) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update(key_states, value_states, cache_kwargs) + + +class QEffEncoderDecoderCache: + """ + Updated the `EncoderDecoderCache` to use the `QEffDynamicCache` for both self-attention and cross-attention caches. + """ + + def __init__(self, self_attention_cache, cross_attention_cache): + # self.layers = [] + # self.layer_classes = QEffDynamicLayer + self.self_attention_cache = self_attention_cache + self.cross_attention_cache = cross_attention_cache + + self.is_updated = {} + for layer_idx in range(len(cross_attention_cache)): + self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache() + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "QEffEncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls( + self_attention_cache=QEffDynamicCache(), + cross_attention_cache=QEffDynamicCache(), + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` + return self.self_attention_cache.get_seq_length(layer_idx, cache_position) + + +# TODO:This function will be deprecated in future. +class QEffHybridCache(QEffDynamicCache): + def __init__(self, config, batch_size, max_cache_len): + layer_classes = [QEffHybridCacheLayer] * config.num_hidden_layers + self.layers = [] + self.layer_classes = layer_classes + self.max_cache_len = max_cache_len + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "QEffHybridCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls(config, batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + +class QEffHybridCacheLayer(QEffDynamicLayer): + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + k_out, v_out = self.keys, self.values + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + sliding_window_pattern = cache_kwargs.get("sliding_window_pattern") + + # remove layer_idx + is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern)) + + layer_ctx_len = self.keys.shape[2] + + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (layer_ctx_len - 1)) + + kv_position_ids = torch.where( + position_ids.max() >= (layer_ctx_len - 1) * 2, (position_ids + 1) % layer_ctx_len, kv_position_ids + ) + else: + kv_position_ids = position_ids + + valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) + key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) + value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(kv_position_ids < 0, INVALID_IDX, kv_position_ids) + else: + scatter_position_ids = kv_position_ids + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + else: + self.keys = CtxScatterFunc.apply(self.keys, kv_position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, kv_position_ids, value_states) + k_out, v_out = self.keys, self.values + + # Gather + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = INVALID_IDX + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + + +# TODO:This function will be deprecated in future. +class QEffHybridChunkedCache(QEffDynamicCache): + def __init__(self, config, max_batch_size, max_cache_len): + layer_classes = [QEffHybridChunkedLayer] * config.num_hidden_layers + self.layers = [] + self.layer_classes = layer_classes + self.max_cache_len = max_cache_len + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "QEffHybridChunkedCache": + """Converts a cache in the legacy cache format into an equivalent `HybridChunkedCache`. Used for + backward compatibility.""" + cache = cls(config, max_batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + +class QEffHybridChunkedLayer(QEffDynamicLayer): + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + k_out, v_out = self.keys, self.values + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + + # handle layer_idx and self.is_sliding + is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) + + # Update the position_ids to handle the sliding window + layer_ctx_len = self.keys.shape[2] + + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (layer_ctx_len - 1)) + + kv_position_ids = torch.where( + position_ids.max() >= (layer_ctx_len - 1) * 2, (position_ids + 1) % layer_ctx_len, kv_position_ids + ) + else: + kv_position_ids = position_ids + + valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) + key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) + value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(kv_position_ids < 0, INVALID_IDX, kv_position_ids) + else: + scatter_position_ids = kv_position_ids + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + else: + self.keys = CtxScatterFunc.apply(self.keys, kv_position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, kv_position_ids, value_states) + k_out, v_out = self.keys, self.values + + # Gather + ctx_len = min(layer_ctx_len, k_out.shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = INVALID_IDX + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + + +# This is a hack for now, until we get to merging this code with HybridCache class, +# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and +# ours are made to work with AIC +class QEffHybridCacheForGPTOSS(QEffDynamicCache): + def __init__(self, config, max_cache_len, sliding_window_len): + layer_classes = [QEffGPTOSSLayer] * config.num_hidden_layers + self.layers = [] + self.layer_classes = layer_classes + self.max_cache_len = max_cache_len + self.sliding_window_len = sliding_window_len + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "QEffHybridCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls( + config, + max_cache_len=past_key_values[1][0].shape[2], + sliding_window_len=past_key_values[0][0].shape[2], + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + +class QEffGPTOSSLayer(QEffDynamicLayer): + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + k_out, v_out = self.keys, self.values + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + is_sliding_layer = cache_kwargs.get("is_sliding") + sliding_window = cache_kwargs.get("sliding_window") + + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) + else: + kv_position_ids = position_ids + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(kv_position_ids < 0, INVALID_IDX, kv_position_ids) + else: + scatter_position_ids = kv_position_ids + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + else: + self.keys = CtxScatterFunc.apply(self.keys, kv_position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, kv_position_ids, value_states) + k_out, v_out = self.keys, self.values + + # Gather + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = INVALID_IDX + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 1cfdf88e1..5744033ee 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -136,7 +136,7 @@ def forward( key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 1edb8ef53..8eff3a930 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -149,7 +149,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 2944601c9..3435f2c46 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -156,7 +156,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index aa14554b2..1adcccf89 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -141,7 +141,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index b158b4046..82dd08b3e 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -139,7 +139,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 73b947dba..59cb50846 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -150,7 +150,7 @@ def forward( key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 5edfb8f3a..31713ee60 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -159,7 +159,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index a3cb4273d..2c1dcb856 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -274,7 +274,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index c088158c4..fe49bef69 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -265,15 +265,13 @@ def attention( v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) if self.config.use_position_ids and self.config.rope: - kv_seq_len = k.shape[-2] kv_seq_len = layer_past.get_seq_length(self.layer_id) # Apply rotary embeddings cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) if not self.config.use_position_ids and self.config.rope: - kv_seq_len = k.shape[-2] - kv_seq_len = layer_past.get_seq_length(kv_seq_len, self.layer_id) + kv_seq_len = layer_past.get_seq_length(self.layer_id) # Apply rotary embeddings cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 00755cae5..deffa8c53 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -150,7 +150,7 @@ def forward( kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index b97a0ab8d..0e93b6208 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -157,7 +157,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 7c093a4b0..e1b1d6262 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -162,7 +162,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index baffb44c5..37d150320 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -503,7 +503,7 @@ def forward( value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index 540bad4c7..776d3eef4 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -163,7 +163,7 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index cbd80d8ca..e7dc367e2 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -213,7 +213,7 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index a03ffecf7..115b99eff 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -75,8 +75,8 @@ def forward( if self.is_decoder: if is_cross_attention and past_key_value: # cross_attentions - key_states_old = past_key_value[self.layer_idx][0] - value_states_old = past_key_value[self.layer_idx][1] + key_states_old = past_key_value.layers[self.layer_idx].keys + value_states_old = past_key_value.layers[self.layer_idx].values key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim) value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim) key_states = key_states.transpose(1, 2).contiguous() @@ -489,11 +489,6 @@ def forward( if not isinstance(past_key_values, Cache): return_legacy_cache = True past_key_values = QEffEncoderDecoderCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) if cache_position is None: cache_position = position_ids diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 1504bdae5..ecb0c3e74 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -8,6 +8,8 @@ import os from dataclasses import dataclass +import torch + UTILS_DIR = os.path.dirname(os.path.abspath(__file__)) QEFF_DIR = os.path.dirname(UTILS_DIR) ROOT_DIR = os.path.dirname(QEFF_DIR) @@ -46,6 +48,9 @@ # Minimum value for causal mask MIN_MASKED_ATTENTION_VALUE = float("-inf") +# Invalid index for position ids to be used during reading and writing the cache +INVALID_IDX = torch.iinfo(torch.int32).max + # Store the qeff_models inside the ~/.cache directory or over-ride with an env variable. def get_models_dir():