Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
import transformers
from packaging import version
from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache

from QEfficient.customop import (
Expand Down Expand Up @@ -55,6 +57,12 @@ def _get_invalid_idx_value(cls):


class QEffDynamicLayer(DynamicLayer):
def lazy_initialization(self, key_states: torch.Tensor):
    """
    Lazily initialize this layer's cache storage from the first `key_states` seen.

    Records the dtype/device of `key_states` on the layer, seeds `keys` and
    `values` with empty tensors of that dtype/device, and marks the layer as
    initialized so subsequent updates can append to real storage.
    """
    # Adopt the incoming tensor's dtype and device for all cache storage.
    self.dtype = key_states.dtype
    self.device = key_states.device
    # Seed with empty tensors; real contents arrive on the first update().
    self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
    self.values = torch.tensor([], dtype=self.dtype, device=self.device)
    self.is_initialized = True

def read_only(self, cache_kwargs):
"""
Reads the `key_states` and `value_states` for the layer.
Expand Down Expand Up @@ -186,10 +194,12 @@ def update(
A tuple containing the updated key and value states.
"""
# Update the cache

if self.keys is None:
self.keys = key_states
self.values = value_states
k_out, v_out = self.keys, self.values
self.is_initialized = True
else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
Expand Down Expand Up @@ -306,15 +316,45 @@ class QEffDynamicCache(DynamicCache):

"""

def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
def __init__(
    self,
    ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
    config=None,
    offloading: bool = False,
    offload_only_non_sliding: bool = False,
    *args,
    **kwargs,
):
    """
    Build a dynamic cache whose layers are all `QEffDynamicLayer`.

    Args:
        ddp_cache_data: Optional iterable of `(key_states, value_states)` pairs;
            one pre-populated layer is appended to the cache for each pair.
        config: Accepted for interface compatibility; not consulted here.
            NOTE(review): presumably intended to infer layer types — confirm
            before relying on it.
        offloading: Forwarded to `Cache.__init__` on transformers >= 4.57.0.
        offload_only_non_sliding: Forwarded to `Cache.__init__` on
            transformers >= 4.57.0.
        *args, **kwargs: Forwarded to `Cache.__init__` on transformers < 4.57.0.
    """
    # Remove layer arguments if present to avoid duplicate-argument errors:
    # this cache always manages its own layer class.
    kwargs.pop("layer_classes", None)
    kwargs.pop("layers", None)
    from transformers.cache_utils import Cache  # Import here to avoid circular import

    # transformers 4.57.0 renamed `layer_classes` to `layer_class_to_replicate`
    # and added the offloading knobs; dispatch on the installed version.
    if version.parse(transformers.__version__) < version.parse("4.57.0"):
        Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
    else:
        Cache.__init__(
            self,
            layer_class_to_replicate=QEffDynamicLayer,
            offloading=offloading,
            offload_only_non_sliding=offload_only_non_sliding,
        )

    if ddp_cache_data is not None:
        for key_states, value_states in ddp_cache_data:
            # Build one layer per entry, seed it with the provided tensors,
            # and register it on the cache (appending to a local list would
            # silently discard the data).
            layer = QEffDynamicLayer()
            layer.update(key_states, value_states)
            self.layers.append(layer)

def read_only(self, layer_idx, cache_kwargs):
"""
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def forward(
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
kv_seq_len = past_key_value.get_seq_length()
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)

Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def forward(
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
kv_seq_len = past_key_value.get_seq_length()
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def forward(
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
kv_seq_len = past_key_value.get_seq_length()
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def forward(
key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
kv_seq_len = past_key_value.get_seq_length()
past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
]
requires-python = ">=3.8,<3.13"
dependencies = [
"transformers==4.55.0",
"transformers==4.57.0",
"diffusers== 0.35.1",
"huggingface-hub==0.34.0",
"hf_transfer==0.1.9",
Expand Down
Loading