From 1200545f55606d56f1241507a5fa0a57c83c1199 Mon Sep 17 00:00:00 2001
From: Dipankar Sarkar
Date: Tue, 17 Mar 2026 10:35:26 +0000
Subject: [PATCH] TF Upgrade to 4.57.0

Signed-off-by: Dipankar Sarkar
---
 QEfficient/transformers/cache_utils.py | 50 +++++++++++++++++--
 .../models/falcon/modeling_falcon.py   |  2 +-
 .../models/gemma/modeling_gemma.py     |  2 +-
 .../models/gemma2/modeling_gemma2.py   |  2 +-
 .../models/llama/modeling_llama.py     |  2 +-
 pyproject.toml                         |  2 +-
 6 files changed, 50 insertions(+), 10 deletions(-)

diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 0e1118407..3bf747058 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -10,6 +10,8 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
+import transformers
+from packaging import version
 from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache
 
 from QEfficient.customop import (
@@ -55,6 +57,12 @@ def _get_invalid_idx_value(cls):
 
 
 class QEffDynamicLayer(DynamicLayer):
+    def lazy_initialization(self, key_states: torch.Tensor):
+        self.dtype, self.device = key_states.dtype, key_states.device
+        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
+        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
+        self.is_initialized = True
+
     def read_only(self, cache_kwargs):
         """
         Reads the `key_states` and `value_states` for the layer.
@@ -186,10 +194,12 @@ def update(
             A tuple containing the updated key and value states.
         """
 
         # Update the cache
+        if self.keys is None:
             self.keys = key_states
             self.values = value_states
             k_out, v_out = self.keys, self.values
+            self.is_initialized = True
         else:
             position_ids = cache_kwargs.get("position_ids")
             batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value form the kwargs
@@ -306,15 +316,45 @@ class QEffDynamicCache(DynamicCache):
 
     """
 
-    def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
+    def __init__(
+        self,
+        ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
+        config=None,
+        offloading: bool = False,
+        offload_only_non_sliding: bool = False,
+        *args,
+        **kwargs,
+    ):
         # Remove layer_classes if present to avoid duplicate argument
-        kwargs.pop("layer_classes", None)
+        kwargs.pop("layers", None)
         from transformers.cache_utils import Cache  # Import here to avoid circular import
 
-        Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
+        layers = []
+        # If a config is passed, use it to infer the layer types and initialize accordingly
+        if len(layers) == 0:
+            if version.parse(transformers.__version__) < version.parse("4.57.0"):
+                Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
+            else:
+                Cache.__init__(
+                    self,
+                    layer_class_to_replicate=QEffDynamicLayer,
+                    offloading=offloading,
+                    offload_only_non_sliding=offload_only_non_sliding,
+                )
+        else:
+            Cache.__init__(
+                self,
+                layers=layers,
+                offloading=offloading,
+                offload_only_non_sliding=offload_only_non_sliding,
+            )
+
         if ddp_cache_data is not None:
-            for key_states, value_states in ddp_cache_data:
-                self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))
+            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
+                # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
+                layers.append(QEffDynamicLayer())
+                # Update the layer with the data
+                _, _ = layers[layer_idx].update(key_states, value_states)
 
     def read_only(self, layer_idx, cache_kwargs):
         """
diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py
index 4ebb2fb96..2f2b34e37 100644
--- a/QEfficient/transformers/models/falcon/modeling_falcon.py
+++ b/QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -137,7 +137,7 @@ def forward(
         key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
         value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
-        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        kv_seq_len = past_key_value.get_seq_length()
 
         cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
         query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py
index 260d1857a..e154063e8 100644
--- a/QEfficient/transformers/models/gemma/modeling_gemma.py
+++ b/QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -149,7 +149,7 @@ def forward(
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        kv_seq_len = past_key_value.get_seq_length()
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
index 6dee8c85d..c6ef79c24 100644
--- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py
+++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -156,7 +156,7 @@ def forward(
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        kv_seq_len = past_key_value.get_seq_length()
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index 57bccdb1b..a01d75a8f 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -226,7 +226,7 @@ def forward(
         key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)
 
-        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        kv_seq_len = past_key_value.get_seq_length()
         past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
diff --git a/pyproject.toml b/pyproject.toml
index 6de8048b4..0280ec2ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
 ]
 requires-python = ">=3.8,<3.13"
 dependencies = [
-    "transformers==4.55.0",
+    "transformers==4.57.0",
     "diffusers== 0.35.1",
     "huggingface-hub==0.34.0",
     "hf_transfer==0.1.9",