Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions QEfficient/generation/embedding_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,11 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -

# Process image and text
inputs = self._processor(images=image, text=prompt, return_tensors="pt")

if (
hasattr(self._qeff_model.model.config, "model_type")
and self._qeff_model.model.config.model_type == "qwen2_5_vl"
):
if hasattr(self._qeff_model.model.config, "model_type") and self._qeff_model.model.config.model_type in {
"qwen2_5_vl",
"qwen3_vl_moe",
"qwen3_vl",
}:
inputs = self._qeff_model.model.prepare_inputs_for_generation(
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
)
Expand Down Expand Up @@ -411,7 +411,7 @@ def setup_vision_buffers(self):
buffers = {}
for output_name, shape in shapes.items():
# Create placeholder with appropriate dtype
if "vision_embeds" in output_name:
if "vision_embeds" in output_name or "deepstack_features" in output_name:
buffers[output_name] = np.zeros(shape, dtype=np.float16)
else:
buffers[output_name] = np.zeros(shape, dtype=np.float32)
Expand Down
24 changes: 11 additions & 13 deletions QEfficient/generation/vlm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ def __init__(
)

# Vision-specific initialization
self.is_qwen2_5_vl = (
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl"
)
self.is_qwen_vl = hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type in {
"qwen2_5_vl",
"qwen3_vl_moe",
"qwen3_vl",
}
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
Expand Down Expand Up @@ -256,13 +258,12 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
outputs, position_ids, generation_len = self.run_prefill(
next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1)
)

if self.is_qwen2_5_vl:
_ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id)
if self.is_qwen_vl:
_ = self.update_decode_inputs_qwen_vl(outputs, position_ids, generation_len, decode_batch_id)
else:
_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None):
def update_decode_inputs_qwen_vl(self, outputs, position_ids, generation_len, decode_batch_id=None):
"""
Updates the decode input with the generated values.
Args:
Expand Down Expand Up @@ -581,14 +582,12 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len)

self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length)
if self.is_qwen2_5_vl:
if self.is_qwen_vl:
self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)

# Create prompt queue
prompt_queue = deque(vision_prompts)

start = perf_counter()

# Pre-process ALL vision inputs and cache them
logger.info("Pre-processing all vision inputs...")
for batch_id in range(min(self.full_batch_size, len(vision_prompts))):
Expand All @@ -610,7 +609,6 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,

# Reset prompt queue for prefill
prompt_queue = deque(vision_prompts)

self.batch_index = None

# Run prefill for all inputs using cached vision
Expand Down Expand Up @@ -692,8 +690,8 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation
generation_len_final = self._fetch_generation_len(generation_len, max_gen_len)

# Update decode inputs
if self.is_qwen2_5_vl:
self.update_decode_inputs_qwen2_5_vl(
if self.is_qwen_vl:
self.update_decode_inputs_qwen_vl(
outputs, position_ids_decode, generation_len_final, decode_batch_id
)
else:
Expand Down
45 changes: 40 additions & 5 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def _get_invalid_idx_value(cls):


class QEffDynamicLayer(DynamicLayer):
def lazy_initialization(self, key_states: torch.Tensor):
self.dtype, self.device = key_states.dtype, key_states.device
self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
self.values = torch.tensor([], dtype=self.dtype, device=self.device)
self.is_initialized = True

def read_only(self, cache_kwargs):
"""
Reads the `key_states` and `value_states` for the layer.
Expand Down Expand Up @@ -186,10 +192,12 @@ def update(
A tuple containing the updated key and value states.
"""
# Update the cache

if self.keys is None:
self.keys = key_states
self.values = value_states
k_out, v_out = self.keys, self.values
self.is_initialized = True
else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
Expand Down Expand Up @@ -306,15 +314,42 @@ class QEffDynamicCache(DynamicCache):

"""

def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
def __init__(
self,
ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
config=None,
offloading: bool = False,
offload_only_non_sliding: bool = False,
*args,
**kwargs,
):
# Remove layer_classes if present to avoid duplicate argument
kwargs.pop("layer_classes", None)
kwargs.pop("layers", None)
from transformers.cache_utils import Cache # Import here to avoid circular import

Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
layers = []
# If a config is passed, use it to infer the layer types and initialize accordingly
if len(layers) == 0:
Cache.__init__(
self,
layer_class_to_replicate=QEffDynamicLayer,
offloading=offloading,
offload_only_non_sliding=offload_only_non_sliding,
)
else:
Cache.__init__(
self,
layers=layers,
offloading=offloading,
offload_only_non_sliding=offload_only_non_sliding,
)

if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data:
self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))
for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
# If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
layers.append(QEffDynamicLayer())
# Update the layer with the data
_, _ = layers[layer_idx].update(key_states, value_states)

def read_only(self, layer_idx, cache_kwargs):
"""
Expand Down
Loading
Loading