From 6d7f593e95e43a3192d84278c3ef08a9fb4ce062 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Fri, 20 Feb 2026 12:18:49 +0000 Subject: [PATCH 01/14] Qwen3Vl Signed-off-by: Dipankar Sarkar --- examples/image_text_to_text/models/qwen3vl/qwen3vl_multi.py | 0 pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 examples/image_text_to_text/models/qwen3vl/qwen3vl_multi.py diff --git a/examples/image_text_to_text/models/qwen3vl/qwen3vl_multi.py b/examples/image_text_to_text/models/qwen3vl/qwen3vl_multi.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyproject.toml b/pyproject.toml index a1082fdfe..abf55a725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ ] requires-python = ">=3.8,<3.13" dependencies = [ - "transformers==4.55.0", + "transformers==4.57.0", "diffusers== 0.35.1", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", From 7972c26967f384c65e006c15fa16f01cce90cf07 Mon Sep 17 00:00:00 2001 From: Onkar Chougule <168134249+ochougul@users.noreply.github.com> Date: Mon, 23 Feb 2026 11:13:29 +0530 Subject: [PATCH 02/14] Add fp8 support (#802) Signed-off-by: Dipankar Sarkar Signed-off-by: Dipankar Sarkar Signed-off-by: Onkar Chougule Co-authored-by: Dipankar Sarkar Signed-off-by: Dipankar Sarkar --- QEfficient/transformers/cache_utils.py | 67 +- .../transformers/models/modeling_auto.py | 15 +- .../transformers/models/pytorch_transforms.py | 141 +-- .../models/qwen3_vl_moe/__init__.py | 6 + .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 1112 +++++++++++++++++ .../quantizers/quant_transforms.py | 63 +- .../quantizer_compressed_tensors.py | 171 ++- .../quantizers/quantizer_utils.py | 24 + examples/qwen3_vl.py | 144 +++ examples/qwen3_vl_moe.py | 134 ++ 10 files changed, 1771 insertions(+), 106 deletions(-) create mode 100644 QEfficient/transformers/models/qwen3_vl_moe/__init__.py create mode 100644 QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py create mode 100644 examples/qwen3_vl.py create mode 100644 examples/qwen3_vl_moe.py diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 0e1118407..42ac119e2 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -55,6 +55,12 @@ def _get_invalid_idx_value(cls): class QEffDynamicLayer(DynamicLayer): + def lazy_initialization(self, key_states: torch.Tensor): + self.dtype, self.device = key_states.dtype, key_states.device + self.keys = torch.tensor([], dtype=self.dtype, device=self.device) + self.values = torch.tensor([], dtype=self.dtype, device=self.device) + self.is_initialized = True + def read_only(self, cache_kwargs): """ Reads the `key_states` and `value_states` for the layer. @@ -151,6 +157,7 @@ def write_only(self, key_states, value_states, cache_kwargs): self.keys = key_states self.values = value_states else: + # breakpoint() position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs @@ -185,11 +192,15 @@ def update( Return: A tuple containing the updated key and value states. """ + # breakpoint() # Update the cache + # if not self.is_initialized: + if self.keys is None: self.keys = key_states self.values = value_states k_out, v_out = self.keys, self.values + self.is_initialized = True else: position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs @@ -306,15 +317,48 @@ class QEffDynamicCache(DynamicCache): """ - def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + def __init__( + self, + ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, + config=None, + offloading: bool = False, + offload_only_non_sliding: bool = False, + *args, + **kwargs, + ): # Remove layer_classes if present to avoid duplicate argument - kwargs.pop("layer_classes", None) + # breakpoint() + kwargs.pop("layers", None) from transformers.cache_utils import Cache # Import here to avoid circular import - Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs) + # breakpoint() + layers = [] + # If a config is passed, use it to infer the layer types and initialize accordingly + if len(layers) == 0: + Cache.__init__( + self, + layer_class_to_replicate=QEffDynamicLayer, + offloading=offloading, + offload_only_non_sliding=offload_only_non_sliding, + # args=args, + # kwargs=kwargs, + ) + else: + Cache.__init__( + self, + layers=layers, + offloading=offloading, + offload_only_non_sliding=offload_only_non_sliding, + # args=args, + # kwargs=kwargs, + ) + if ddp_cache_data is not None: - for key_states, value_states in ddp_cache_data: - self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states)) + for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data): + # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data + layers.append(QEffDynamicLayer()) + # Update the layer with the data + _, _ = layers[layer_idx].update(key_states, value_states) def read_only(self, layer_idx, cache_kwargs): """ @@ -329,6 +373,7 @@ def read_only(self, layer_idx, cache_kwargs): Return: A tuple containing the updated key and value states. """ + # breakpoint() return self.layers[layer_idx].read_only(cache_kwargs) def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs): @@ -394,6 +439,18 @@ def update3D( self.append_new_layers(layer_idx) return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs) + # def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + # """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # # TODO: deprecate this function in favor of `cache_position` + # breakpoint() + # is_empty_layer = ( + # len(self.key_cache) == 0 # no cache in any layer + # or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + # or len(self.key_cache[layer_idx]) == 0 # the layer has no cache + # ) + # layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + # return layer_seq_length + class QEffEncoderDecoderCache(EncoderDecoderCache): """ diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index d44638aa0..c72854981 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -64,6 +64,8 @@ from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import ( AwqToMatmulNbitsTransform, + FP8BlockWiseDequantLinearToLinearTransform, + FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform, FP8DeQuantLinearToLinearTransform, GPTQToMatmulNbitsTransform, Mxfp4GptOssExpertDequantizeTransform, @@ -999,6 +1001,8 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): _pytorch_transforms = [ AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, + FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform, + FP8BlockWiseDequantLinearToLinearTransform, CustomOpsTransform, KVCacheTransform, VlmKVOffloadTransform, @@ -1666,6 +1670,7 @@ def kv_offload_generate( AssertionError If `generation_len` is not greater than zero. """ + # breakpoint() if not self.lang_model.qpc_path: raise TypeError("Please run compile API for language model first!") @@ -1697,7 +1702,7 @@ def kv_offload_generate( [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] ) - + # breakpoint() input_len = inputs["attention_mask"].sum(1, keepdims=True) input_ids_length = inputs["input_ids"].shape[1] num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float @@ -1743,7 +1748,7 @@ def kv_offload_generate( vision_end = perf_counter() lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - + # breakpoint() if "position_ids" in inputs: lang_inputs["position_ids"] = inputs["position_ids"] lang_inputs.pop("attention_mask") @@ -1755,7 +1760,7 @@ def kv_offload_generate( not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" if not_mllama: lang_inputs["image_idx"] = np.array([[0]]) - + # breakpoint() if self.vision_model.qpc_path: vision_session.deactivate() lang_session.activate() @@ -1770,7 +1775,7 @@ def kv_offload_generate( lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] lang_start = perf_counter() - + # breakpoint() # Run prefill chunk_inputs = lang_inputs.copy() for i in range(num_chunks): @@ -1802,7 +1807,7 @@ def kv_offload_generate( ) if not_mllama: lang_session.skip_buffers(vision_outputs.keys()) - + # breakpoint() # Get first token lang_inputs["input_ids"] = outputs["logits"].argmax(2) lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index f1daf3014..90ce1de15 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -116,11 +116,6 @@ MistralModel, MistralRMSNorm, ) -from transformers.models.mistral3.modeling_mistral3 import ( - Mistral3ForConditionalGeneration, - Mistral3Model, - Mistral3RMSNorm, -) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -143,13 +138,6 @@ MllamaVisionModel, ) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel -from transformers.models.olmo2.modeling_olmo2 import ( - Olmo2Attention, - Olmo2DecoderLayer, - Olmo2ForCausalLM, - Olmo2Model, - Olmo2RMSNorm, -) from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel from transformers.models.phi3.modeling_phi3 import ( Phi3Attention, @@ -158,7 +146,6 @@ Phi3Model, Phi3RMSNorm, ) -from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralVisionModel from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2DecoderLayer, @@ -171,7 +158,6 @@ Qwen2_5_VLAttention, Qwen2_5_VLDecoderLayer, Qwen2_5_VLForConditionalGeneration, - Qwen2_5_VLModel, Qwen2_5_VLTextModel, Qwen2_5_VLVisionAttention, ) @@ -185,14 +171,15 @@ Qwen3Model, Qwen3RMSNorm, ) -from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - Qwen3MoeAttention, - Qwen3MoeDecoderLayer, - Qwen3MoeForCausalLM, - Qwen3MoeModel, - Qwen3MoeRMSNorm, - Qwen3MoeRotaryEmbedding, - Qwen3MoeSparseMoeBlock, +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel, + Qwen3VLMoeTextAttention, + Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextModel, + Qwen3VLMoeTextRMSNorm, + Qwen3VLMoeVisionAttention, + Qwen3VLMoeVisionModel, ) from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, @@ -346,11 +333,6 @@ QEffMistralForCausalLM, QEffMistralModel, ) -from QEfficient.transformers.models.mistral3.modeling_mistral3 import ( - QEffMistral3ForConditionalGeneration, - QEffMistral3Model, - QEffPixtralVisionModel, -) from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( QEffMixtralAttention, QeffMixtralDecoderLayer, @@ -371,25 +353,12 @@ QEffMllamaTextSelfAttention, QEffMllamaVisionModel, ) -from QEfficient.transformers.models.molmo.modeling_molmo import ( - QEffMolmo, - QEffMolmoBlock, - QEffMolmoModel, - QEffMolmoSequentialBlock, - QEffMultiHeadDotProductAttention, -) from QEfficient.transformers.models.mpt.modeling_mpt import ( QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel, ) -from QEfficient.transformers.models.olmo2.modeling_olmo2 import ( - QEffOlmo2Attention, - QEffOlmo2DecoderLayer, - QEffOlmo2ForCausalLM, - QEffOlmo2Model, -) from QEfficient.transformers.models.phi.modeling_phi import ( QEffPhiAttention, QEffPhiDecoderLayer, @@ -412,10 +381,9 @@ QEffQwen2_5_VisionTransformerPretrainedModel, QEffQwen2_5_VLAttention, QEffQwen2_5_VLDecoderLayer, - QEffQwen2_5_VLModel, QEffQwen2_5_VLTextModel, + # QEffQwen2_5_VLModel, QEffQwen2_5_VLVisionAttention, - QEffQwen_2_5_vl_DecoderWrapper, QEffQwen_2_5_vl_ForConditionalGeneration, ) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( @@ -433,6 +401,16 @@ QEffQwen3MoeRotaryEmbedding, QEffQwen3MoeSparseMoeBlock, ) +from QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + QEffQwen3VLMoeForConditionalGeneration, + QEffQwen3VLMoeModel, + QEffQwen3VLMoeTextAttention, + QEffQwen3VLMoeTextDecoderLayer, + QEffQwen3VLMoeTextModel, + # QEffQwen3VLMoeTextSparseMoeBlock, + QEffQwen3VLMoeVisionAttention, + QEffQwen3VLMoeVisionModel, +) from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( QEffStarcoder2Attention, QEFFStarcoder2DecoderLayer, @@ -467,7 +445,6 @@ class CustomOpsTransform(ModuleMappingTransform): LlamaRMSNorm: CustomRMSNormAIC, Llama4TextRMSNorm: CustomRMSNormAIC, MistralRMSNorm: CustomRMSNormAIC, - Mistral3RMSNorm: CustomRMSNormAIC, MixtralRMSNorm: CustomRMSNormAIC, Phi3RMSNorm: CustomRMSNormAIC, Qwen2RMSNorm: CustomRMSNormAIC, @@ -475,11 +452,11 @@ class CustomOpsTransform(ModuleMappingTransform): Qwen2_5RMSNorm: CustomRMSNormAIC, MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, - PixtralRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, - Qwen3MoeRMSNorm: CustomRMSNormAIC, + Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC, + # Qwen3VLTextRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, - Olmo2RMSNorm: CustomRMSNormAIC, + # Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC, } @@ -532,12 +509,12 @@ class KVCacheTransform(ModuleMappingTransform): GemmaModel: QEffGemmaModel, GemmaForCausalLM: QEffGemmaForCausalLM, # Qwen3Moe - Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, - Qwen3MoeModel: QEffQwen3MoeModel, - Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, - Qwen3MoeAttention: QEffQwen3MoeAttention, - Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, - Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, + # Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, + # Qwen3MoeModel: QEffQwen3MoeModel, + # Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, + # Qwen3MoeAttention: QEffQwen3MoeAttention, + # Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, + # Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, # Gemma2 Gemma2Attention: QEffGemma2Attention, Gemma2DecoderLayer: QEffGemma2DecoderLayer, @@ -585,9 +562,6 @@ class KVCacheTransform(ModuleMappingTransform): MistralDecoderLayer: QEffMistralDecoderLayer, MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, - # Mistral3 - Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, - Mistral3Model: QEffMistral3Model, # Mixtral MixtralAttention: QEffMixtralAttention, MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, @@ -609,26 +583,38 @@ class KVCacheTransform(ModuleMappingTransform): PhiDecoderLayer: QEffPhiDecoderLayer, PhiModel: QEffPhiModel, PhiForCausalLM: QEffPhiForCausalLM, - # Pixtral - PixtralVisionModel: QEffPixtralVisionModel, # Qwen2 Qwen2Attention: QEffQwen2Attention, Qwen2DecoderLayer: QEffQwen2DecoderLayer, Qwen2Model: QEffQwen2Model, Qwen2ForCausalLM: QEffQwen2ForCausalLM, + # Qwen2.5 VL + Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, + # Qwen2_5_VLModel: QEffQwen2_5_VLModel, + Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, # Qwen3 Qwen3Attention: QEffQwen3Attention, Qwen3DecoderLayer: QEffQwen3DecoderLayer, Qwen3Model: QEffQwen3Model, Qwen3ForCausalLM: QEffQwen3ForCausalLM, # Qwen2.5 VL - Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, - Qwen2_5_VLModel: QEffQwen2_5_VLModel, + # Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, + # Qwen2_5_VLModel: QEffQwen2_5_VLModel, Qwen2_5_VLAttention: QEffQwen2_5_VLAttention, Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer, Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention, - Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, + # Qwen3vlmoe + Qwen3VLMoeForConditionalGeneration: QEffQwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel: QEffQwen3VLMoeModel, + Qwen3VLMoeTextAttention: QEffQwen3VLMoeTextAttention, + Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer, + Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention, + Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel, + Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel, + # Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, + # Grok1 + # Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, # Starcoder2 Starcoder2Attention: QEffStarcoder2Attention, Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, @@ -639,11 +625,6 @@ class KVCacheTransform(ModuleMappingTransform): GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM, - # Olmo2 - Olmo2Attention: QEffOlmo2Attention, - Olmo2DecoderLayer: QEffOlmo2DecoderLayer, - Olmo2Model: QEffOlmo2Model, - Olmo2ForCausalLM: QEffOlmo2ForCausalLM, # Whisper encoder and decoder layers WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding, WhisperAttention: QEffWhisperAttention, @@ -719,7 +700,7 @@ class SpDTransform: # Llama QEffLlamaForCausalLM, QEffQwen2ForCausalLM, - QEffQwen3ForCausalLM, + # QEffQwen3ForCausalLM, } @classmethod @@ -785,7 +766,7 @@ class SamplerTransform: QEffMptForCausalLM, QEffPhi3ForCausalLM, QEffQwen2ForCausalLM, - QEffQwen_2_5_vl_DecoderWrapper, + # QEffQwen_2_5_vl_DecoderWrapper, } @classmethod @@ -831,32 +812,6 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, - # Mapping for Molmo - "MolmoForCausalLM": { - "forward": QEffMolmoModel.forward, - "get_qeff_vision_encoder": QEffMolmoModel.get_qeff_vision_encoder, - "get_qeff_language_decoder": QEffMolmoModel.get_qeff_language_decoder, - "get_specializations": QEffMolmoModel.get_specializations, - "get_onnx_dynamic_axes": QEffMolmoModel.get_onnx_dynamic_axes, - "get_output_names": QEffMolmoModel.get_output_names, - "get_dummy_inputs": QEffMolmoModel.get_dummy_inputs, - "get_inputs_info": QEffMolmoModel.get_inputs_info, - }, - "RMSLayerNorm": {"forward": CustomRMSNormAIC.forward}, - # "MolmoForCausalLM": {"forward": QEffMolmoForCausalLM.forward}, - "Molmo": {"forward": QEffMolmo.forward}, - "MolmoSequentialBlock": { - "forward": QEffMolmoSequentialBlock.forward, - "attention": QEffMolmoBlock.attention, - "__qeff_init__": QEffMolmoBlock.__qeff_init__, - }, - "MolmoBlock": { - "attention": QEffMolmoBlock.attention, - "__qeff_init__": QEffMolmoBlock.__qeff_init__, - }, - "MultiHeadDotProductAttention": { - "forward": QEffMultiHeadDotProductAttention.forward, - }, # Mapping for grok1 model "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, "Grok1Model": { diff --git a/QEfficient/transformers/models/qwen3_vl_moe/__init__.py b/QEfficient/transformers/models/qwen3_vl_moe/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/transformers/models/qwen3_vl_moe/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py new file mode 100644 index 000000000..8c2532b18 --- /dev/null +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -0,0 +1,1112 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel, + Qwen3VLMoeModelOutputWithPast, + Qwen3VLMoeTextAttention, + Qwen3VLMoeTextConfig, + Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextModel, + Qwen3VLMoeTextRotaryEmbedding, + Qwen3VLMoeTextSparseMoeBlock, + Qwen3VLMoeVisionAttention, + Qwen3VLMoeVisionModel, + apply_rotary_pos_emb_vision, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class QEffQwen3VLMoeTextRotaryEmbedding(Qwen3VLMoeTextRotaryEmbedding): + def __init__(self, config: Qwen3VLMoeTextConfig, device=None): + super().__init__(config, device) + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + position_ids = torch.arange(seq_len, device=device, dtype=torch.long) + position_ids = position_ids.unsqueeze(0).expand(3, 1, -1) # (3, 1, seq_len) + + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, 1, -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # (3, 1, 1, seq_len) + + device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" + + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + + freqs_interleaved = self._apply_interleaved_mrope_cached(freqs, self.mrope_section) + emb = torch.cat((freqs_interleaved, freqs_interleaved), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * self.attention_scaling).squeeze(0).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.attention_scaling).squeeze(0).to(dtype), persistent=False + ) + + def _apply_interleaved_mrope_cached(self, freqs, mrope_section): + freqs_t = freqs[0].clone() # (bs, seq_len, head_dim // 2) + for dim, offset in enumerate((1, 2), start=1): + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + def forward(self, x, position_ids, seq_len=None): + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + seq_len = position_ids.shape[-1] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + cos = self.cos_cached[:seq_len].to(dtype=x.dtype) + sin = self.sin_cached[:seq_len].to(dtype=x.dtype) + + if position_ids.shape[1] > 1: + cos = cos.unsqueeze(0).expand(position_ids.shape[1], -1, -1) + sin = sin.unsqueeze(0).expand(position_ids.shape[1], -1, -1) + else: + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) + return cos, sin + + +class QEffQwen3VLMoeVisionModel(Qwen3VLMoeVisionModel): + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + max_hw = max(grid_thw.shape) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + bs, num_frames, height, width = grid_thw.shape + grid_thw = (torch.tensor(grid_thw.shape, dtype=torch.int64)).unsqueeze(0) + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + pos_ids = coords + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + bs, t, h, w = grid_thw.shape + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + max_t = torch.tensor(self.num_grid_per_side - 1, device=h_idxs.device) + + h_idxs_ceil = torch.minimum(h_idxs_floor + 1, max_t) # working + w_idxs_ceil = torch.minimum(w_idxs_floor + 1, max_t) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + idx_tensor = torch.stack(indices, dim=0).to(dtype=torch.long, device=self.pos_embed.weight.device) # [4, h*w] + + weight_tensor = torch.stack(weights, dim=0).to( + dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + pos_embed = patch_pos_embeds[0] + pos_embed = pos_embed.repeat(t, 1) + + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + x_expanded = patch_pos_embeds.unsqueeze(0) + x_expanded = x_expanded.expand(bs, -1, -1) + patch_pos_embeds = x_expanded.reshape(-1, patch_pos_embeds.size(1)) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + bs, t, h, w = grid_thw.shape + + t = torch.arange(t, t + 1).squeeze().expand(bs) + h = torch.arange(h, h + 1).squeeze().expand(bs) + w = torch.arange(w, w + 1).squeeze().expand(bs) + + cu_seqlens = (h * w).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), cu_seqlens]) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + return hidden_states, deepstack_feature_lists + + +class QEffQwen3VLMoeVisionAttention(Qwen3VLMoeVisionAttention): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + + # Create index grids + seq_len = attention_mask.shape[-1] + rows = torch.arange(seq_len).view(1, -1) + cols = torch.arange(seq_len).view(-1, 1) + + # Prepare start and end indices + start = cu_seqlens[:-1].view(-1, 1, 1) + end = cu_seqlens[1:].view(-1, 1, 1) + + # Create block masks using broadcasting + row_mask = (rows >= start) & (rows < end) + col_mask = (cols >= start) & (cols < end) + block_mask = row_mask & col_mask # shape: (num_blocks, seq_len, seq_len) + + # Combine all blocks into one mask + final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask[block_mask.any(dim=0)] = 0 + + final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) + + attention_mask[0] = final_mask + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffQwen3VLMoeTextAttention(Qwen3VLMoeTextAttention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + cos, sin = position_embeddings + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids[1:], self.config.rope_scaling["mrope_section"] + ) + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids[0], + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_values=past_key_values, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_values + + +class QEffQwen3VLMoeTextDecoderLayer(Qwen3VLMoeTextDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + comp_ctx_lengths=comp_ctx_lengths, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states[0] + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + return outputs + + +class QEffQwen3VLMoeTextModel(Qwen3VLMoeTextModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + visual_pos_masks: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if self.config.use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids[0], target_length=target_length, sliding_window=None + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + layer_idx = 0 + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class QEffQwen3VLMoeModel(Qwen3VLMoeModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + output = Qwen3VLMoeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +class QEffQwen3VLEncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.visual + + def forward(self, pixel_values, image_grid_thw): + image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0] + bs = image_grid_thw.shape[0] + split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) + image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1)) + return image_embeds + + +class QEffQwen3VLDecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.language_model = self.model.model + + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, + ): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + outputs = self.model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, + ) + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return logits, vision_embeds, image_idx, outputs.past_key_values + + +class QEffQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + hidden_states = hidden_states.view(T, H) + router_logits = self.gate(hidden_states) # [T, E] + prob = F.softmax(router_logits, -1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, -1) + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + gate_proj_up_w = self.experts.gate_up_proj.requires_grad_(False)[top_i.flatten()] + down_proj_w = self.experts.down_proj.requires_grad_(False)[top_i.flatten()] + + expert_in = hidden_states.unsqueeze(1).expand(-1, self.top_k, -1).contiguous().view(-1, 1, H) + gate_up = torch.bmm(expert_in, gate_proj_up_w) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + intermediate = up * self.experts.act_fn(gate) + experts_out = torch.bmm(intermediate, down_proj_w) + experts_out = experts_out.view(B * S, self.top_k, H) + experts_out = experts_out * top_w.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + return experts_out.view(B, S, H), router_logits + + +class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): + def get_qeff_vision_encoder(self): + return QEffQwen3VLEncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEffQwen3VLDecoderWrapper(self) + + # def forward( + # self, + # input_ids, + # position_ids, + # past_key_values, + # pixel_values:Optional[torch.FloatTensor] = None, + # image_idx:Optional[torch.LongTensor] = None, + # comp_ctx_lengths: Optional[List[int]] = None, + # batch_index: Optional[torch.LongTensor] = None, + # image_grid_thw=None, + # ): + # image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0] + # bs = image_grid_thw.shape[0] + # split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) + + # inputs_embeds = self.model.get_input_embeddings()(input_ids) + # B, N, C = inputs_embeds.shape + # selected = input_ids == self.model.config.image_token_id + # indices1 = selected.to(torch.int64).cumsum(1) - 1 + # indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + # indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + # image_features_expanded = image_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + # image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + # outputs = self.language_model( + # inputs_embeds=inputs_embeds, + # position_ids=position_ids, + # past_key_values=past_key_values, + # comp_ctx_lengths=comp_ctx_lengths, + # batch_index=batch_index, + # use_cache=True, + # ) + # logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + # hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + # logits = self.lm_head(hidden_states) + # image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + # return logits, image_embeds, image_idx, outputs.past_key_values + + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + vision_size = 187 + inputs_shapes["vision_embeds"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + vision_size, + self.model.config.vision_config.out_hidden_size, + ) + inputs_shapes["image_grid_thw"] = (1, 1, 22, 34) + inputs_shapes["position_ids"] = ( + 3, + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = (748, 1536) + inputs_shapes["image_idx"] = (1, 1) + inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2) + + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["position_ids"] = ( + ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + .unsqueeze(0) + .repeat(4, 1, 1) + ) + lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + # Add data for KV + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + kv_cache_shape = get_padding_shape_from_config( + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] + for i in range(self.model.config.text_config.num_hidden_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vision_embeds") + inputs = {**vision_inputs, **lang_inputs} + return inputs + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + img_size: None, + height: int = None, + width: int = None, + num_frames: int = 1, + kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, + **compiler_options, + ): + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + if height is None or width is None: + height = 1365 + width = 2048 + logger.warning( + "Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config" + ) + prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 + ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN + channel = 3 + patch_size = self.config.vision_config.patch_size + temporal_patch_size = self.config.vision_config.temporal_patch_size + + IMAGE_FACTOR = 32 + MIN_PIXELS = 64 * 32 * 32 + MAX_PIXELS = 16384 * 32 * 32 + MAX_RATIO = 200 + + def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + def smart_resize( + height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + ) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + resized_height, resized_width = smart_resize(height=height, width=width) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + grid_height = grid_h * grid_w + grid_width = patch_size * patch_size * temporal_patch_size * channel + vision_size = grid_height // 4 + vision_size = vision_size * num_frames + grid_height = grid_height * batch_size + + vision = [ + { + "batch_size": batch_size, + "vision_size": vision_size, + "grid_height": grid_height, + "grid_width": grid_width, + "grid_h": grid_h, + "grid_w": grid_w, + } + ] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] + + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") + return lang, compiler_options + + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): + # Define dynamic axes + num_layers = self.config.text_config.num_hidden_layers + vision_dynamic_axes = { + "pixel_values": {0: "grid_height", 1: "grid_width"}, + "image_grid_thw": {0: "batch_size", 2: "grid_h", 3: "grid_w"}, + } + + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {1: "batch_size", 2: "seq_len"}, + "vision_embeds": {0: "batch_size", 1: "vision_size"}, + } + + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + + dynamic_axes = {} + + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + lang_dynamic_axes.pop("vision_embeds") + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vision_embeds"] + lang_output_names = ["logits"] + for i in range(self.model.config.text_config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + lang_output_names.insert(2, "image_idx_output") + return lang_output_names + return output_names + + def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=1): + input_ids_length = inputs["input_ids"].shape[1] + inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1) + pos_ids, rope_deltas = self.model.get_rope_index( + inputs["input_ids"], + None if "image_grid_thw" not in inputs else inputs["image_grid_thw"], + video_grid_thw=None, + attention_mask=inputs["attention_mask"], + ) + + inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0) + + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + return inputs + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")), + ] diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py index 69d6380f0..f97bfe998 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -7,15 +7,22 @@ import torch from torch import nn +from transformers import AutoConfig from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts from QEfficient.base.pytorch_transforms import ModuleMutatorTransform from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ -from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear +from QEfficient.transformers.quantizers.quantizer_compressed_tensors import ( + FP8BlockWiseDequantLinear, + FP8BlockWiseDequantQwen3VLMoeTextExperts, + FP8DeQuantLinear, +) from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts from QEfficient.transformers.quantizers.quantizer_utils import ( + blockwise_dequantize, convert_moe_packed_tensors, dequantize_gptq, unpack_weights, @@ -146,3 +153,57 @@ def mutate(cls, original_module, parent_module): dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias dequant_module.down_proj_bias = original_module.down_proj_bias return dequant_module + + +class FP8BlockWiseDequantLinearToLinearTransform(ModuleMutatorTransform): + """ + Used to dequantize the weights of an FP8BlockWiseDequantLinear module and replace with a regular Linear layer + """ + + _match_class = FP8BlockWiseDequantLinear + + @classmethod + def mutate(cls, original_module, parent_module): + # -- de-quantizing the weights -- + dequant_weights = blockwise_dequantize( + original_module.weight, original_module.weight_scale_inv, original_module.weight_block_size + ) + dequant_linear_layer = nn.Linear( + original_module.in_features, original_module.out_features, bias=original_module.bias is not None + ) + dequant_linear_layer.weight = torch.nn.Parameter(dequant_weights) + if original_module.bias is not None: + dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float()) + return dequant_linear_layer + + +class FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform(ModuleMutatorTransform): + _match_class = FP8BlockWiseDequantQwen3VLMoeTextExperts + _model_type = "qwen3_vl_moe" + + @classmethod + def mutate(cls, original_module, parent_module): + config = AutoConfig.for_model(cls._model_type).text_config + config.num_experts = original_module.num_experts + config.intermediate_size = original_module.intermediate_size + config.hidden_size = original_module.hidden_size + assert original_module.act_fn.__class__.__name__ == "SiLUActivation", ( + "Only SiLU activation is supported for now." + ) + assert config.hidden_act == "silu", "expected silu act fn, something changed in transformers code" + dequant_module = Qwen3VLMoeTextExperts(config) + dequant_module.gate_up_proj = torch.nn.Parameter( + blockwise_dequantize( + original_module.gate_up_proj, + original_module.gate_up_proj_scale_inv, + original_module.weights_block_size, + ) + ) + dequant_module.down_proj = torch.nn.Parameter( + blockwise_dequantize( + original_module.down_proj, + original_module.down_proj_scale_inv, + original_module.weights_block_size, + ) + ) + return dequant_module diff --git a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py index e7e14166d..382677bcf 100644 --- a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py +++ b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py @@ -10,10 +10,11 @@ from typing import List import torch +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer from transformers.utils.quantization_config import CompressedTensorsConfig, QuantizationConfigMixin, QuantizationMethod -from QEfficient.transformers.quantizers.quantizer_utils import get_keys_to_not_convert +from QEfficient.transformers.quantizers.quantizer_utils import blockwise_dequantize, get_keys_to_not_convert from QEfficient.utils.logging_utils import logger FP8_DTYPE = torch.float8_e4m3fn @@ -128,6 +129,118 @@ def forward(self, x): return out +class FP8BlockWiseDequantLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + weight_block_size: List[int], + bias: bool = False, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight_block_size = weight_block_size + + self.register_buffer( + "weight", + torch.empty( + (out_features, in_features), dtype=FP8_DTYPE + ), # This is fixed for now and only e4m3fn quantization is prominent + ) + + if bias: + self.register_buffer( + "bias", + torch.zeros( + (out_features), + dtype=torch.float16, + ), + ) + else: + self.bias = None + + @classmethod + def for_fp8_layer_with_blocksize(cls, in_features, out_features, weight_block_size, fmt, bias): + fp8_dequant_layer = cls(in_features, out_features, weight_block_size, bias) + assert fmt == "e4m3", "e5m2 is not supposed yet!!" + assert (in_features % weight_block_size[0]) == 0 and (out_features % weight_block_size[1]) == 0, ( + "weight shape is not divisible by block sizes in either rows or columns or both dimensions, \ + got in_features: {in_features}, out_features: {out_features}, weight_block_size: {weight_block_size}!!" + ) + fp8_dequant_layer.register_buffer( + "weight_scale_inv", + torch.empty( + (out_features // weight_block_size[0], in_features // weight_block_size[1]), dtype=torch.float32 + ), + ) + return fp8_dequant_layer + + def __repr__(self): + return f"FP8BlockWiseDequantLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias})" + + def forward(self, x): + with torch.no_grad(): + dequantized_weights = blockwise_dequantize(self.weight, self.weight_scale_inv, self.weight_block_size) + out = torch.matmul(x.float(), dequantized_weights.T) + out = out + self.bias if self.bias is not None else out + + return out + + +class FP8BlockWiseDequantQwen3VLMoeTextExperts(torch.nn.Module): + def __init__(self, num_experts, moe_intermediate_size, hidden_size, act_fn, weights_block_size): + super().__init__() + self.num_experts = num_experts + self.intermediate_size = moe_intermediate_size + self.hidden_size = hidden_size + self.expert_dim = self.intermediate_size + self.weights_block_size = weights_block_size + r, c = weights_block_size + self.register_buffer( + "gate_up_proj", torch.empty((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=FP8_DTYPE) + ) + self.register_buffer( + "down_proj", torch.empty((self.num_experts, self.expert_dim, self.hidden_size), dtype=FP8_DTYPE) + ) + self.register_buffer( + "gate_up_proj_scale_inv", + torch.empty((self.num_experts, self.hidden_size // r, (2 * self.expert_dim) // c), dtype=torch.float32), + ) + self.register_buffer( + "down_proj_scale_inv", + torch.empty((self.num_experts, self.expert_dim // r, self.hidden_size // c), dtype=torch.float32), + ) + self.act_fn = act_fn + + @classmethod + def for_fp8_layer_with_blocksize(cls, old_module, weight_block_size, fmt): + assert fmt == "e4m3", "e5m2 is not supposed yet!!" + fp8_experts = cls( + num_experts=old_module.num_experts, + moe_intermediate_size=old_module.intermediate_size, + hidden_size=old_module.hidden_size, + act_fn=old_module.act_fn, + weights_block_size=weight_block_size, + ) + return fp8_experts + + def forward(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor): + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up_proj = blockwise_dequantize(self.gate_up_proj, self.gate_up_proj_inv_scale, self.weights_block_size) + down_proj = blockwise_dequantize(self.down_proj, self.down_proj_inv_scale, self.weights_block_size) + gate_up = torch.bmm(hidden_states, gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), down_proj) + next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + + class QEffFP8Config(QuantizationConfigMixin): def __init__( self, @@ -136,6 +249,8 @@ def __init__( ignored_layers: List[str] = None, kv_cache_scheme: str = None, run_compressed: bool = False, + fmt: str = None, + weight_block_size: List[int] = None, ): self.quant_method = quant_method self.activation_scheme = activation_scheme @@ -155,6 +270,52 @@ def __init__( ) self.quant_method = QEffExtendedQuantizationMethod.FP8 + self.fmt = fmt + self.weight_block_size = weight_block_size + + +def _replace_with_fp8_dequant_linear_and_experts_if_qwen( + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False +): + current_key_name = [] if current_key_name is None else current_key_name + + for name, child_module in model.named_children(): + current_key_name.append(name) + + if isinstance(child_module, torch.nn.Linear) and name not in (modules_to_not_convert or []): + current_key_name_str = ".".join(current_key_name) + if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): + model._modules[name] = FP8BlockWiseDequantLinear.for_fp8_layer_with_blocksize( + child_module.in_features, + child_module.out_features, + quantization_config.weight_block_size, + quantization_config.fmt, + child_module.bias is not None, + ) + has_been_replaced = True + + if isinstance(child_module, Qwen3VLMoeTextExperts) and name not in (modules_to_not_convert or []): + # Replace the MoE experts + current_key_name_str = ".".join(current_key_name) + if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): + model._modules[name] = FP8BlockWiseDequantQwen3VLMoeTextExperts.for_fp8_layer_with_blocksize( + child_module, + quantization_config.weight_block_size, + quantization_config.fmt, + ) + has_been_replaced = True + + if len(list(child_module.children())) > 0: + _, has_been_replaced = _replace_with_fp8_dequant_linear_and_experts_if_qwen( + child_module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + + current_key_name.pop(-1) + return model, has_been_replaced class QEffFP8Quantizer(CompressedTensorsHfQuantizer): @@ -196,6 +357,12 @@ def _process_model_before_weight_loading(self, model, **kwargs): f"activations quantization strategy = {self.quantization_config.activation_scheme}, will be ignored and the layers will be run with de-quantized weights" ) + if self.quantization_config.weight_block_size is not None: + model, has_been_replaced = _replace_with_fp8_dequant_linear_and_experts_if_qwen( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + return + # -- Defining local method as it uses lot of local variables -- def replace_linear_with_fp8_dequant_layer(module): for name, child_module in module.named_children(): @@ -218,7 +385,7 @@ def _process_model_after_weight_loading(self, model, **kwargs): def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]: return missing_keys - def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]: + def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str = None) -> List[str]: return unexpected_keys diff --git a/QEfficient/transformers/quantizers/quantizer_utils.py b/QEfficient/transformers/quantizers/quantizer_utils.py index 424692d08..4060a162d 100644 --- a/QEfficient/transformers/quantizers/quantizer_utils.py +++ b/QEfficient/transformers/quantizers/quantizer_utils.py @@ -7,6 +7,7 @@ import copy import math +from typing import List import torch from torch import nn @@ -446,3 +447,26 @@ def convert_moe_packed_tensors( out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) out = out.to(dtype).permute(0, 2, 1).contiguous() return out + + +def blockwise_dequantize( + quantized: torch.Tensor, + scales: torch.Tensor, + block_size: List[int] = None, + **kwargs, +) -> dict[str, torch.Tensor]: + rows, cols = quantized.shape[-2:] + if block_size is None: + block_size = (quantized.shape[-2], quantized.shape[-1]) + + block_m, block_n = block_size + + if rows % block_m != 0 or cols % block_n != 0: + raise ValueError(f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}).") + quantized = quantized.to(scales.dtype) + reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n) + expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n) + expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2) + dequantized = reshaped * expanded_scales + + return dequantized.reshape(quantized.shape) diff --git a/examples/qwen3_vl.py b/examples/qwen3_vl.py new file mode 100644 index 000000000..6609dbe2f --- /dev/null +++ b/examples/qwen3_vl.py @@ -0,0 +1,144 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +model_id = "Qwen/Qwen3-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +# For Testing Purpose Only +config.vision_config.depth = 1 +config.text_config.num_hidden_layers = 1 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config +) +# breakpoint() +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=1024, + width=1024, + mxfp6_matmul=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + # breakpoint() + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + # breakpoint() + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + # height=354, + # width=536, + height=1024, + width=1024, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + # image_url = "https://picsum.photos/id/237/536/354" + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Descibe the image in details."}, + ], + }, + ] + + # messages_2 = [ + # { + # "role": "user", + # "content": [ + # {"type": "image", "image": image}, + # {"type": "text", "text": "Describe about the color of the dog."}, + # ], + # }, + # ] + + messages = [messages_1] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + # breakpoint() + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/qwen3_vl_moe.py b/examples/qwen3_vl_moe.py new file mode 100644 index 000000000..931cfe093 --- /dev/null +++ b/examples/qwen3_vl_moe.py @@ -0,0 +1,134 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +# For Testing Purpose Only +config.vision_config.depth = 1 +config.text_config.num_hidden_layers = 1 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config +) + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Descibe all the colors seen in the image."}, + ], + }, + ] + + # messages_2 = [ + # { + # "role": "user", + # "content": [ + # {"type": "image", "image": image}, + # {"type": "text", "text": "Describe about the color of the dog."}, + # ], + # }, + # ] + + messages = [messages_1] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) From 5b3ac38f00b316b1cbf21576cb77a7bfbc26ab95 Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Thu, 26 Feb 2026 11:27:20 +0530 Subject: [PATCH 03/14] Deepstack features for Qwen3VL dense (#807) Signed-off-by: Dipankar Sarkar Signed-off-by: Dipankar Sarkar Signed-off-by: Onkar Chougule Signed-off-by: Dhiraj Kumar Sah Co-authored-by: Dipankar Sarkar Co-authored-by: Onkar Chougule <168134249+ochougul@users.noreply.github.com> Signed-off-by: Dipankar Sarkar --- .../transformers/models/modeling_auto.py | 15 +- .../transformers/models/qwen3_vl/__init__.py | 6 + .../models/qwen3_vl/modeling_qwen3_vl.py | 1086 +++++++++++++++++ examples/qwen3_vl.py | 37 +- 4 files changed, 1111 insertions(+), 33 deletions(-) create mode 100644 QEfficient/transformers/models/qwen3_vl/__init__.py create mode 100644 QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c72854981..bf7cfe742 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1517,13 +1517,19 @@ def compile( for output_name in output_names["lang"]: if output_name.endswith("_RetainedState"): custom_io_lang[output_name[: -len("_RetainedState")]] = ( - "float16" if "vision_embeds" in output_name else kv_cache_dtype + "float16" + if ("vision_embeds" in output_name or "deepstack_features" in output_name) + else kv_cache_dtype ) # outputs for output_name in output_names["lang"]: if output_name.endswith("_RetainedState"): - custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype + custom_io_lang[output_name] = ( + "float16" + if ("vision_embeds" in output_name or "deepstack_features" in output_name) + else kv_cache_dtype + ) self.lang_model._compile( compile_dir=compile_dir, compile_only=True, @@ -1702,7 +1708,6 @@ def kv_offload_generate( [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] ) - # breakpoint() input_len = inputs["attention_mask"].sum(1, keepdims=True) input_ids_length = inputs["input_ids"].shape[1] num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float @@ -1748,7 +1753,6 @@ def kv_offload_generate( vision_end = perf_counter() lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - # breakpoint() if "position_ids" in inputs: lang_inputs["position_ids"] = inputs["position_ids"] lang_inputs.pop("attention_mask") @@ -1760,7 +1764,6 @@ def kv_offload_generate( not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" if not_mllama: lang_inputs["image_idx"] = np.array([[0]]) - # breakpoint() if self.vision_model.qpc_path: vision_session.deactivate() lang_session.activate() @@ -1775,7 +1778,6 @@ def kv_offload_generate( lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] lang_start = perf_counter() - # breakpoint() # Run prefill chunk_inputs = lang_inputs.copy() for i in range(num_chunks): @@ -1807,7 +1809,6 @@ def kv_offload_generate( ) if not_mllama: lang_session.skip_buffers(vision_outputs.keys()) - # breakpoint() # Get first token lang_inputs["input_ids"] = outputs["logits"].argmax(2) lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 diff --git a/QEfficient/transformers/models/qwen3_vl/__init__.py b/QEfficient/transformers/models/qwen3_vl/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/transformers/models/qwen3_vl/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py new file mode 100644 index 000000000..cc12e6f39 --- /dev/null +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -0,0 +1,1086 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + Qwen3VLModelOutputWithPast, + Qwen3VLTextAttention, + Qwen3VLTextDecoderLayer, + Qwen3VLTextModel, + Qwen3VLVisionAttention, + Qwen3VLVisionModel, + apply_rotary_pos_emb_vision, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class QEffQwen3VLVisionModel(Qwen3VLVisionModel): + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = max(grid_thw.shape) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + bs, num_frames, height, width = grid_thw.shape + grid_thw = (torch.tensor(grid_thw.shape, dtype=torch.int64)).unsqueeze(0) + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + pos_ids = coords + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + bs, t, h, w = grid_thw.shape + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + max_t = torch.tensor(self.num_grid_per_side - 1, device=h_idxs.device) + + h_idxs_ceil = torch.minimum(h_idxs_floor + 1, max_t) # working + w_idxs_ceil = torch.minimum(w_idxs_floor + 1, max_t) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + idx_tensor = torch.stack(indices, dim=0).to(dtype=torch.long, device=self.pos_embed.weight.device) # [4, h*w] + + weight_tensor = torch.stack(weights, dim=0).to( + dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + pos_embed = patch_pos_embeds[0] + pos_embed = pos_embed.repeat(t, 1) + + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + x_expanded = patch_pos_embeds.unsqueeze(0) + x_expanded = x_expanded.expand(bs, -1, -1) + patch_pos_embeds = x_expanded.reshape(-1, patch_pos_embeds.size(1)) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + bs, t, h, w = grid_thw.shape + + t = torch.arange(t, t + 1).squeeze().expand(bs) + h = torch.arange(h, h + 1).squeeze().expand(bs) + w = torch.arange(w, w + 1).squeeze().expand(bs) + + cu_seqlens = (h * w).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), cu_seqlens]) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + return hidden_states, deepstack_feature_lists + + +class QEffQwen3VLVisionAttention(Qwen3VLVisionAttention): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + + # Create index grids + seq_len = attention_mask.shape[-1] + rows = torch.arange(seq_len).view(1, -1) + cols = torch.arange(seq_len).view(-1, 1) + + # Prepare start and end indices + start = cu_seqlens[:-1].view(-1, 1, 1) + end = cu_seqlens[1:].view(-1, 1, 1) + + # Create block masks using broadcasting + row_mask = (rows >= start) & (rows < end) + col_mask = (cols >= start) & (cols < end) + block_mask = row_mask & col_mask # shape: (num_blocks, seq_len, seq_len) + + # Combine all blocks into one mask + final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask[block_mask.any(dim=0)] = 0 + + final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) + + attention_mask[0] = final_mask + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffQwen3VLTextAttention(Qwen3VLTextAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids[1:], self.config.rope_scaling["mrope_section"] + ) + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids[0], + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_values=past_key_values, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_values + + +class QEffQwen3VLTextDecoderLayer(Qwen3VLTextDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + comp_ctx_lengths=comp_ctx_lengths, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states[0] + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + return outputs + + +class QEffQwen3VLTextModel(Qwen3VLTextModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + visual_pos_masks: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if self.config.use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids[0], target_length=target_length, sliding_window=None + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + layer_idx = 0 + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if deepstack_visual_embeds is not None and layer_idx in range(deepstack_visual_embeds.shape[0]): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + layer_idx += 1 + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return (hidden_states, past_key_values) + + def _deepstack_process( + self, + hidden_states: torch.Tensor, + visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor, + ): + visual_pos_masks = visual_pos_masks.unsqueeze(-1).expand(-1, -1, self.config.hidden_size) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states = hidden_states.clone() + mixed_embeds = hidden_states + visual_embeds + + local_this = torch.where(visual_pos_masks, mixed_embeds, hidden_states) + + return local_this + + +class QEffQwen3VLEncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.visual + + def forward(self, pixel_values, image_grid_thw): + image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw) + bs = image_grid_thw.shape[0] + split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) + image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1)) + deepstack_features = torch.stack( + [feature.reshape(bs, split_size, feature.size(1)) for feature in deepstack_feature_lists], + dim=0, # new axis for "features" + ) + return image_embeds, deepstack_features + + +class QEffQwen3VLDecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.language_model = self.model.model + + def forward( + self, + input_ids, + vision_embeds, + deepstack_features, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, + ): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + + num_features, bs, split_size, C = deepstack_features.shape + x = deepstack_features.reshape(num_features, bs * split_size, C) + deepstack_features_expanded = x[:, indices1, :] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + + image_mask = selected.clone() + + visual_pos_masks = None + deepstack_visual_embeds = None + + if image_mask is not None: + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_features_expanded + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values + + +class QEffQwen3VLModel(Qwen3VLModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + output = Qwen3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +class QEffQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): + def get_qeff_vision_encoder(self): + return QEffQwen3VLEncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEffQwen3VLDecoderWrapper(self) + + def forward( + self, + input_ids, + position_ids, + past_key_values, + pixel_values: Optional[torch.FloatTensor] = None, + image_idx: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, + image_grid_thw=None, + ): + image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw) + + inputs_embeds = self.model.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = image_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + # TODO: deepstack_features are not processed for single QPC setup yet. Will do if required. + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + outputs = self.language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, + ) + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return logits, image_embeds, image_idx, outputs.past_key_values + + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + # vision_size = 1024 + vision_size = 187 + inputs_shapes["vision_embeds"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + vision_size, + self.model.config.vision_config.out_hidden_size, + ) + inputs_shapes["image_grid_thw"] = (1, 1, 22, 34) + inputs_shapes["position_ids"] = ( + 3, + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = (748, 1536) + inputs_shapes["image_idx"] = (1, 1) + inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2) + inputs_shapes["deepstack_features"] = ( + len(self.config.vision_config.deepstack_visual_indexes), + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + vision_size, + self.model.config.vision_config.out_hidden_size, + ) + + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["position_ids"] = ( + ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + .unsqueeze(0) + .repeat(4, 1, 1) + ) + lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + lang_inputs["deepstack_features"] = torch.zeros((inputs_shapes["deepstack_features"]), dtype=torch.float32) + # Add data for KV + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + kv_cache_shape = get_padding_shape_from_config( + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] + for i in range(self.model.config.text_config.num_hidden_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vision_embeds") + inputs = {**vision_inputs, **lang_inputs} + return inputs + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + img_size: None, + height: int = None, + width: int = None, + time: int = 1, + # dimensions: List = None, + num_frames: int = 1, + kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, + **compiler_options, + ): + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + if height is None or width is None: + height = 1365 + width = 2048 + logger.warning( + "Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config" + ) + prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 + ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN + channel = 3 + patch_size = self.config.vision_config.patch_size + temporal_patch_size = self.config.vision_config.temporal_patch_size + + IMAGE_FACTOR = 32 + MIN_PIXELS = 64 * 32 * 32 + MAX_PIXELS = 16384 * 32 * 32 + MAX_RATIO = 200 + + def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + def smart_resize( + height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + ) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + resized_height, resized_width = smart_resize(height=height, width=width) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + grid_height = grid_h * grid_w + grid_width = patch_size * patch_size * temporal_patch_size * channel + vision_size = grid_height // 4 + vision_size = vision_size * num_frames * time + grid_height = grid_height * time * batch_size + + vision = [ + { + "batch_size": batch_size, + "vision_size": vision_size, + "grid_height": grid_height, + "grid_width": grid_width, + "time": time, + "grid_h": grid_h, + "grid_w": grid_w, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), + } + ] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] + + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") + return lang, compiler_options + + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): + # Define dynamic axes + num_layers = self.config.text_config.num_hidden_layers + vision_dynamic_axes = { + "pixel_values": {0: "grid_height", 1: "grid_width"}, + "image_grid_thw": {0: "batch_size", 1: "time", 2: "grid_h", 3: "grid_w"}, + "deepstack_features": {0: "num_feature_layers", 1: "batch_size", 2: "vision_size"}, + } + + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {1: "batch_size", 2: "seq_len"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, + "deepstack_features": {0: "num_feature_layers", 1: "vision_batch_size", 2: "vision_size"}, + } + + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + + dynamic_axes = {} + + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + lang_dynamic_axes.pop("vision_embeds") + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vision_embeds"] + vision_output_names.append("deepstack_features") + lang_output_names = ["logits"] + for i in range(self.model.config.text_config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") + lang_output_names.insert(2, "deepstack_features_RetainedState") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + lang_output_names.insert(2, "image_idx_output") + return lang_output_names + return output_names + + def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=1): + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1) + pos_ids, rope_deltas = self.model.get_rope_index( + inputs["input_ids"], + None if "image_grid_thw" not in inputs else inputs["image_grid_thw"], + video_grid_thw=None, + attention_mask=inputs["attention_mask"], + ) + + inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0) + + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + return inputs + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")), + ] diff --git a/examples/qwen3_vl.py b/examples/qwen3_vl.py index 6609dbe2f..f955de4f4 100644 --- a/examples/qwen3_vl.py +++ b/examples/qwen3_vl.py @@ -6,28 +6,20 @@ # ----------------------------------------------------------------------------- import requests -import transformers from PIL import Image from qwen_vl_utils import process_vision_info from transformers import AutoConfig, AutoProcessor, TextStreamer from QEfficient import QEFFAutoModelForImageTextToText -# model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" model_id = "Qwen/Qwen3-VL-32B-Instruct" config = AutoConfig.from_pretrained(model_id) -# For Testing Purpose Only -config.vision_config.depth = 1 -config.text_config.num_hidden_layers = 1 - qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config ) -# breakpoint() -tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) -### use skip_vision=Ture, if want to run only text, ow false ### +### use skip_vision=Ture, if want to run only text, else false ### skip_vision = False if skip_vision: @@ -41,8 +33,8 @@ ctx_len=4096, num_cores=16, num_devices=4, - height=1024, - width=1024, + height=354, + width=536, mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, @@ -67,13 +59,11 @@ return_dict=True, return_tensors="pt", ) - # breakpoint() inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - # breakpoint() - streamer = TextStreamer(tokenizer) + streamer = TextStreamer(processor.tokenizer) output = qeff_model.generate(inputs=inputs, generation_len=100) print(output.generated_ids) - print(tokenizer.batch_decode(output.generated_ids)) + print(processor.tokenizer.batch_decode(output.generated_ids)) print(output) else: @@ -85,10 +75,8 @@ ctx_len=4096, num_cores=16, num_devices=4, - # height=354, - # width=536, - height=1024, - width=1024, + height=354, + width=536, mxfp6_matmul=True, mxint8_kv_cache=True, aic_enable_depth_first=True, @@ -96,10 +84,8 @@ ) ### IMAGE + TEXT ### - # image_url = "https://picsum.photos/id/237/536/354" - image_url = ( - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" - ) + image_url = "https://picsum.photos/id/237/536/354" + # image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" image = Image.open(requests.get(image_url, stream=True).raw) @@ -136,9 +122,8 @@ return_tensors="pt", ) inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - # breakpoint() - streamer = TextStreamer(tokenizer) + streamer = TextStreamer(processor.tokenizer) output = qeff_model.generate(inputs=inputs, generation_len=100) print(output.generated_ids) - print(tokenizer.batch_decode(output.generated_ids)) + print(processor.tokenizer.batch_decode(output.generated_ids)) print(output) From 40d8b2269f77b7163708192b849d08207aac8911 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Thu, 26 Feb 2026 11:39:24 +0530 Subject: [PATCH 04/14] Onboarding Qwen3VlMoe (#590) The Onboarding of Qwen3VlMoe --------- Signed-off-by: Dipankar Sarkar Signed-off-by: Dipankar Sarkar Signed-off-by: vtirumal Signed-off-by: Onkar Chougule Co-authored-by: vtirumal Co-authored-by: Onkar Chougule <168134249+ochougul@users.noreply.github.com> Signed-off-by: Dipankar Sarkar --- QEfficient/generation/embedding_handler.py | 9 +- QEfficient/generation/vlm_generation.py | 36 ++++- QEfficient/transformers/cache_utils.py | 2 - .../transformers/models/pytorch_transforms.py | 5 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 79 ++++------- examples/qwen3_vl.py | 11 +- examples/qwen3_vl_moe/qwen3_vl_moe.py | 134 ++++++++++++++++++ .../qwen3_vl_moe_contnious_batching.py | 69 +++++++++ 9 files changed, 279 insertions(+), 68 deletions(-) create mode 100644 examples/qwen3_vl_moe/qwen3_vl_moe.py create mode 100644 examples/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index e07b5dd04..7845b400e 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -252,7 +252,6 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - # Process image and text inputs = self._processor(images=image, text=prompt, return_tensors="pt") - if ( hasattr(self._qeff_model.model.config, "model_type") and self._qeff_model.model.config.model_type == "qwen2_5_vl" @@ -260,6 +259,14 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - inputs = self._qeff_model.model.prepare_inputs_for_generation( inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] ) + + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "qwen3_vl_moe" + ): + inputs = self._qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) # Convert to float32 if needed if "pixel_values" in inputs: diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index adacc373e..43d660d0c 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -149,6 +149,9 @@ def __init__( self.is_qwen2_5_vl = ( hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" ) + self.is_qwen3_vl_moe=( + hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe" + ) self.qeff_model = qeff_model self.processor = processor self.tokenizer = tokenizer @@ -256,9 +259,10 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): outputs, position_ids, generation_len = self.run_prefill( next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) ) - if self.is_qwen2_5_vl: _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) + elif self.is_qwen3_vl_moe: + _ = self.update_decode_inputs_qwen3_vl_moe(outputs,position_ids,generation_len,decode_batch_id) else: _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) @@ -283,6 +287,27 @@ def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id + def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_len, decode_batch_id=None): + """ + Updates the decode input with the generated values. + Args: + outputs (dict): The outputs of the model. + position_ids (array): The position IDs. + generation_len (int): The generation length. + decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. + + Returns: + next_token_id (array): The next token ID. + """ + next_token_id = self._fetch_next_token_id(outputs) + + # Store the generated values. + self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id + self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) + self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) + self.generation_len[decode_batch_id or slice(None)] = generation_len + return next_token_id + def _execute_chunked_prefill( self, lang_inputs: Dict[str, np.ndarray], @@ -583,12 +608,12 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) if self.is_qwen2_5_vl: self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) - + if self.is_qwen3_vl_moe: + self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) # Create prompt queue prompt_queue = deque(vision_prompts) start = perf_counter() - # Pre-process ALL vision inputs and cache them logger.info("Pre-processing all vision inputs...") for batch_id in range(min(self.full_batch_size, len(vision_prompts))): @@ -610,7 +635,6 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, # Reset prompt queue for prefill prompt_queue = deque(vision_prompts) - self.batch_index = None # Run prefill for all inputs using cached vision @@ -696,6 +720,10 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation self.update_decode_inputs_qwen2_5_vl( outputs, position_ids_decode, generation_len_final, decode_batch_id ) + if self.is_qwen3_vl_moe: + self.update_decode_inputs_qwen3_vl_moe( + outputs, position_ids_decode, generation_len_final, decode_batch_id + ) else: self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id) else: diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 42ac119e2..8561bba60 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -327,11 +327,9 @@ def __init__( **kwargs, ): # Remove layer_classes if present to avoid duplicate argument - # breakpoint() kwargs.pop("layers", None) from transformers.cache_utils import Cache # Import here to avoid circular import - # breakpoint() layers = [] # If a config is passed, use it to infer the layer types and initialize accordingly if len(layers) == 0: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 90ce1de15..f4360f1a7 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -178,6 +178,7 @@ Qwen3VLMoeTextDecoderLayer, Qwen3VLMoeTextModel, Qwen3VLMoeTextRMSNorm, + Qwen3VLMoeTextSparseMoeBlock, Qwen3VLMoeVisionAttention, Qwen3VLMoeVisionModel, ) @@ -407,7 +408,7 @@ QEffQwen3VLMoeTextAttention, QEffQwen3VLMoeTextDecoderLayer, QEffQwen3VLMoeTextModel, - # QEffQwen3VLMoeTextSparseMoeBlock, + QEffQwen3VLMoeTextSparseMoeBlock, QEffQwen3VLMoeVisionAttention, QEffQwen3VLMoeVisionModel, ) @@ -612,7 +613,7 @@ class KVCacheTransform(ModuleMappingTransform): Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention, Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel, Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel, - # Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, + Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, # Grok1 # Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, # Starcoder2 diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index d6bfbda81..630965790 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -591,7 +591,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + # kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 8c2532b18..8dd4e0b14 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -711,23 +711,29 @@ class QEffQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S - hidden_states = hidden_states.view(T, H) - router_logits = self.gate(hidden_states) # [T, E] - prob = F.softmax(router_logits, -1, dtype=torch.float) - top_w, top_i = torch.topk(prob, self.top_k, -1) - top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) - gate_proj_up_w = self.experts.gate_up_proj.requires_grad_(False)[top_i.flatten()] - down_proj_w = self.experts.down_proj.requires_grad_(False)[top_i.flatten()] - - expert_in = hidden_states.unsqueeze(1).expand(-1, self.top_k, -1).contiguous().view(-1, 1, H) - gate_up = torch.bmm(expert_in, gate_proj_up_w) - gate, up = gate_up[..., ::2], gate_up[..., 1::2] + x = hidden_states.view(T, H) + + router_logits = self.gate(x) + prob = F.softmax(router_logits, dim=-1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) + top_w = top_w / top_w.sum(dim=1, keepdim=True) + top_w = top_w.to(x.dtype) + idx = top_i.reshape(-1) + w_up = self.experts.gate_up_proj.index_select(0, idx) + w_dn = self.experts.down_proj.index_select(0, idx) + + xk = x.unsqueeze(1).expand(-1, self.top_k, -1).contiguous() + xk = xk.view(-1, 1, H) + gate_up = torch.bmm(xk, w_up) + I2 = gate_up.size(-1) + half = I2 // 2 + gate, up = gate_up[..., :half], gate_up[..., half:] intermediate = up * self.experts.act_fn(gate) - experts_out = torch.bmm(intermediate, down_proj_w) - experts_out = experts_out.view(B * S, self.top_k, H) - experts_out = experts_out * top_w.unsqueeze(-1) - experts_out = experts_out.sum(dim=1) - return experts_out.view(B, S, H), router_logits + experts_out = torch.bmm(intermediate, w_dn) + experts_out = experts_out.view(T, self.top_k, H) * top_w.unsqueeze(-1) + experts_out = experts_out.sum(dim=1).view(B, S, H) + + return experts_out, router_logits class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): @@ -737,44 +743,6 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen3VLDecoderWrapper(self) - # def forward( - # self, - # input_ids, - # position_ids, - # past_key_values, - # pixel_values:Optional[torch.FloatTensor] = None, - # image_idx:Optional[torch.LongTensor] = None, - # comp_ctx_lengths: Optional[List[int]] = None, - # batch_index: Optional[torch.LongTensor] = None, - # image_grid_thw=None, - # ): - # image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0] - # bs = image_grid_thw.shape[0] - # split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) - - # inputs_embeds = self.model.get_input_embeddings()(input_ids) - # B, N, C = inputs_embeds.shape - # selected = input_ids == self.model.config.image_token_id - # indices1 = selected.to(torch.int64).cumsum(1) - 1 - # indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) - # indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - # image_features_expanded = image_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] - # image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) - # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) - # outputs = self.language_model( - # inputs_embeds=inputs_embeds, - # position_ids=position_ids, - # past_key_values=past_key_values, - # comp_ctx_lengths=comp_ctx_lengths, - # batch_index=batch_index, - # use_cache=True, - # ) - # logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) - # hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] - # logits = self.lm_head(hidden_states) - # image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - # return logits, image_embeds, image_idx, outputs.past_key_values - def get_dummy_inputs( self, comp_ctx_lengths: Optional[List[int]] = None, @@ -1036,7 +1004,7 @@ def get_onnx_dynamic_axes( lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {1: "batch_size", 2: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } for i in range(num_layers): @@ -1102,6 +1070,7 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size= inputs["position_ids"] = F.pad( inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 ) + inputs.pop("image_grid_thw", None) return inputs def get_inputs_info(self): diff --git a/examples/qwen3_vl.py b/examples/qwen3_vl.py index f955de4f4..88dc346c1 100644 --- a/examples/qwen3_vl.py +++ b/examples/qwen3_vl.py @@ -18,6 +18,7 @@ qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config ) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) ### use skip_vision=Ture, if want to run only text, else false ### skip_vision = False @@ -60,7 +61,7 @@ return_tensors="pt", ) inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - streamer = TextStreamer(processor.tokenizer) + streamer = TextStreamer(tokenizer) output = qeff_model.generate(inputs=inputs, generation_len=100) print(output.generated_ids) print(processor.tokenizer.batch_decode(output.generated_ids)) @@ -77,6 +78,8 @@ num_devices=4, height=354, width=536, + # height=1024, + # width=1024, mxfp6_matmul=True, mxint8_kv_cache=True, aic_enable_depth_first=True, @@ -85,7 +88,9 @@ ### IMAGE + TEXT ### image_url = "https://picsum.photos/id/237/536/354" - # image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + # image_url = ( + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + # ) image = Image.open(requests.get(image_url, stream=True).raw) @@ -122,7 +127,7 @@ return_tensors="pt", ) inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - streamer = TextStreamer(processor.tokenizer) + streamer = TextStreamer(tokenizer) output = qeff_model.generate(inputs=inputs, generation_len=100) print(output.generated_ids) print(processor.tokenizer.batch_decode(output.generated_ids)) diff --git a/examples/qwen3_vl_moe/qwen3_vl_moe.py b/examples/qwen3_vl_moe/qwen3_vl_moe.py new file mode 100644 index 000000000..931cfe093 --- /dev/null +++ b/examples/qwen3_vl_moe/qwen3_vl_moe.py @@ -0,0 +1,134 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +# For Testing Purpose Only +config.vision_config.depth = 1 +config.text_config.num_hidden_layers = 1 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config +) + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Descibe all the colors seen in the image."}, + ], + }, + ] + + # messages_2 = [ + # { + # "role": "user", + # "content": [ + # {"type": "image", "image": image}, + # {"type": "text", "text": "Describe about the color of the dog."}, + # ], + # }, + # ] + + messages = [messages_1] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py b/examples/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py new file mode 100644 index 000000000..f209ad87b --- /dev/null +++ b/examples/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py @@ -0,0 +1,69 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +config = AutoConfig.from_pretrained(model_id) +config.vision_config.depth = 1 +config.text_config.num_hidden_layers = 1 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) From 8cffdab986f846d9ba6f00b4db5df0f249e98c7d Mon Sep 17 00:00:00 2001 From: Karthikeya Date: Thu, 26 Feb 2026 16:09:09 +0530 Subject: [PATCH 05/14] Support for Qwen3VL MOE Disagg mode (#808) Signed-off-by: Dipankar Sarkar Signed-off-by: Dipankar Sarkar Signed-off-by: vtirumal Co-authored-by: Dipankar Sarkar Signed-off-by: Dipankar Sarkar --- QEfficient/generation/embedding_handler.py | 2 +- QEfficient/generation/vlm_generation.py | 4 +- QEfficient/transformers/cache_utils.py | 3 - .../transformers/models/modeling_auto.py | 120 ++++++-- .../transformers/models/pytorch_transforms.py | 3 + .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 58 +++- examples/qwen3_vl.py | 1 + examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py | 272 ++++++++++++++++++ 8 files changed, 430 insertions(+), 33 deletions(-) create mode 100644 examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index 7845b400e..13fcc6f93 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -259,7 +259,7 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - inputs = self._qeff_model.model.prepare_inputs_for_generation( inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] ) - + if ( hasattr(self._qeff_model.model.config, "model_type") and self._qeff_model.model.config.model_type == "qwen3_vl_moe" diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 43d660d0c..0df3e6511 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -149,7 +149,7 @@ def __init__( self.is_qwen2_5_vl = ( hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" ) - self.is_qwen3_vl_moe=( + self.is_qwen3_vl_moe = ( hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe" ) self.qeff_model = qeff_model @@ -262,7 +262,7 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): if self.is_qwen2_5_vl: _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) elif self.is_qwen3_vl_moe: - _ = self.update_decode_inputs_qwen3_vl_moe(outputs,position_ids,generation_len,decode_batch_id) + _ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id) else: _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 8561bba60..37215702a 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -157,7 +157,6 @@ def write_only(self, key_states, value_states, cache_kwargs): self.keys = key_states self.values = value_states else: - # breakpoint() position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs @@ -192,7 +191,6 @@ def update( Return: A tuple containing the updated key and value states. """ - # breakpoint() # Update the cache # if not self.is_initialized: @@ -371,7 +369,6 @@ def read_only(self, layer_idx, cache_kwargs): Return: A tuple containing the updated key and value states. """ - # breakpoint() return self.layers[layer_idx].read_only(cache_kwargs) def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index bf7cfe742..bbf5b7270 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1036,7 +1036,36 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): + def prefill( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + if enable: + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) + + else: + if retain_full_kv: + self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + + def export( + self, + inputs, + output_names, + dynamic_axes, + export_dir=None, + offload_pt_weights=True, + prefill_seq_len: Optional[int] = None, + prefill_only: bool = False, + enable_chunking: bool = False, + **kwargs, + ): """ Exports the language decoder component to ONNX format. @@ -1060,6 +1089,18 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt str Path to the generated ONNX graph file for the language decoder. """ + if prefill_only: + if prefill_seq_len > 1: + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) + self.hash_params["prefill_only"] = True + self.prefill(enable=True, enable_chunking=enable_chunking) + else: + self.hash_params["prefill_only"] = False + self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + return self._export( inputs, output_names=output_names, @@ -1267,6 +1308,11 @@ def export( self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, + skip_vision: Optional[bool] = False, + skip_lang: Optional[bool] = False, + prefill_seq_len: Optional[int] = None, + prefill_only: bool = False, + enable_chunking: bool = False, **kwargs, ) -> str: """ @@ -1320,26 +1366,33 @@ def export( vocab_size=self.model.language_model.config.vocab_size, qaic_config=self.lang_model.model.qaic_config, ) + if not skip_vision: + self.vision_model.export( + inputs["vision"], + output_names["vision"], + dynamic_axes["vision"], + export_dir=export_dir, + offload_pt_weights=False, + use_onnx_subfunctions=use_onnx_subfunctions, + ) - self.vision_model.export( - inputs["vision"], - output_names["vision"], - dynamic_axes["vision"], - export_dir=export_dir, - offload_pt_weights=False, - use_onnx_subfunctions=use_onnx_subfunctions, - ) - - offload_pt_weights = kwargs.get("offload_pt_weights", True) - self.lang_model.export( - inputs["lang"], - output_names["lang"], - dynamic_axes["lang"], - export_dir=export_dir, - offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=use_onnx_subfunctions, - ) + if prefill_only and prefill_seq_len > 1: + offload_pt_weights = False # to keep weight for decode onnx + else: + offload_pt_weights = kwargs.get("offload_pt_weights", True) + if not skip_lang: + self.lang_model.export( + inputs["lang"], + output_names["lang"], + dynamic_axes["lang"], + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + prefill_seq_len=prefill_seq_len, + ) return self.onnx_path def compile( @@ -1363,6 +1416,8 @@ def compile( skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, use_onnx_subfunctions: bool = False, + prefill_only=False, + enable_chunking=False, **compiler_options, ) -> str: """ @@ -1481,11 +1536,18 @@ def compile( if lang_onnx_path: self.lang_model.onnx_path = lang_onnx_path - if (self.vision_model.onnx_path is None and vision_onnx_path is None) or ( - self.lang_model.onnx_path is None and lang_onnx_path is None + if ( + (self.vision_model.onnx_path is None and vision_onnx_path is None) + or (self.lang_model.onnx_path is None and lang_onnx_path is None) + or prefill_only ): self.export( use_onnx_subfunctions=use_onnx_subfunctions, + skip_vision=skip_vision, + skip_lang=skip_lang, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + prefill_seq_len=prefill_seq_len, ) # TODO this hould be removed once the continous batching is supported for all the models. @@ -1530,11 +1592,20 @@ def compile( if ("vision_embeds" in output_name or "deepstack_features" in output_name) else kv_cache_dtype ) + + if prefill_only: + if prefill_seq_len > 1: + specializations = specializations["lang"][:1] # prefill + else: + specializations = specializations["lang"][-1:] # decoder + else: + specializations = specializations["lang"] + self.lang_model._compile( + onnx_path=self.lang_model.onnx_path, compile_dir=compile_dir, compile_only=True, - retained_state=True, - specializations=specializations["lang"], + specializations=specializations, convert_to_fp16=True, mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, @@ -1544,6 +1615,8 @@ def compile( use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) + if skip_vision and prefill_only: # for disagg serving + return self.lang_model.qpc_path return self.qpc_path def generate( @@ -1676,7 +1749,6 @@ def kv_offload_generate( AssertionError If `generation_len` is not greater than zero. """ - # breakpoint() if not self.lang_model.qpc_path: raise TypeError("Please run compile API for language model first!") diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index f4360f1a7..00215b15e 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -403,6 +403,7 @@ QEffQwen3MoeSparseMoeBlock, ) from QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock, QEffQwen3VLMoeForConditionalGeneration, QEffQwen3VLMoeModel, QEffQwen3VLMoeTextAttention, @@ -658,6 +659,8 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, # Qwen3Moe QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock, + # Qwen3 VL Moe + QEffQwen3VLMoeTextSparseMoeBlock: QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock, } diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 8dd4e0b14..dbae48c4a 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -659,6 +659,15 @@ def __init__(self, model): self.model = model self.model.vision_model = self.model.visual + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + return {self.model.visual.blocks[0].__class__} + def forward(self, pixel_values, image_grid_thw): image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0] bs = image_grid_thw.shape[0] @@ -671,7 +680,16 @@ class QEffQwen3VLDecoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model - self.language_model = self.model.model + self.language_model = self.model.model.language_model + + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + return {QEffQwen3VLDecoderWrapper} def forward( self, @@ -714,7 +732,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens x = hidden_states.view(T, H) router_logits = self.gate(x) - prob = F.softmax(router_logits, dim=-1, dtype=torch.float) + prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) top_w, top_i = torch.topk(prob, self.top_k, dim=-1) top_w = top_w / top_w.sum(dim=1, keepdim=True) top_w = top_w.to(x.dtype) @@ -736,6 +754,40 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return experts_out, router_logits +class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + act = getattr(self.experts, "act_fn", F.silu) + + router_logits = self.gate(x) # [T, E] + prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k] + top_w = top_w / top_w.sum(dim=-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + + # gate_up_proj: [E, H, 2I], down_proj: [E, I, H] + W_up = self.experts.gate_up_proj + W_dn = self.experts.down_proj + E, H_w, twoI = W_up.shape + I2 = twoI // 2 + routing_weights = torch.zeros_like(prob, dtype=hidden_states.dtype) # [T, E] + routing_weights.scatter_(1, top_i, top_w) + expert_out = x.new_zeros((T, H)) + for e in range(E): + rw = routing_weights[:, e].unsqueeze(-1) # [T, 1] + # Split fused [H, 2I] -> [H, I] + [H, I] + W_gate_e = W_up[e, :, :I2] + W_up_e = W_up[e, :, I2:] + W_dn_e = W_dn[e, :, :] + gate = x @ W_gate_e + up = x @ W_up_e + down = (up * act(gate)) @ W_dn_e + expert_out.add_(down * rw) + return expert_out.view(B, S, H), router_logits + + class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): def get_qeff_vision_encoder(self): return QEffQwen3VLEncoderWrapper(self) diff --git a/examples/qwen3_vl.py b/examples/qwen3_vl.py index 88dc346c1..75bed1630 100644 --- a/examples/qwen3_vl.py +++ b/examples/qwen3_vl.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import requests +import transformers from PIL import Image from qwen_vl_utils import process_vision_info from transformers import AutoConfig, AutoProcessor, TextStreamer diff --git a/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py new file mode 100644 index 000000000..0e308c1e7 --- /dev/null +++ b/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -0,0 +1,272 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from time import perf_counter + +import numpy as np +import requests +import torch +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +# TODO clean up this script +# For Testing Purpose Only +# config.vision_config.depth = 1 +# config.text_config.num_hidden_layers = 1 +num_devices = 4 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +PREFILL_SEQ_LEN = 128 +CTX_LEN = 4096 +BS = 1 +torch.manual_seed(0) + +skip_vision = True + +prefill_qpc_path = qeff_model.compile( + batch_size=BS, + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + height=354, + width=536, + num_cores=16, + num_devices=num_devices, + mxfp6_matmul=True, + mxint8_kv_cache=True, + retain_full_kv=True, + split_retained_state_io=True, + retained_state=True, + mos=1, + aic_enable_depth_first=True, + prefill_only=True, + enable_chunking=True, + skip_vision=True, + use_onnx_subfunctions=False, +) + + +decode_qpc_path = qeff_model.compile( + batch_size=BS, + prefill_seq_len=1, + ctx_len=CTX_LEN, + height=354, + width=536, + num_cores=16, + num_devices=num_devices, + mxfp6_matmul=True, + mxint8_kv_cache=True, + retain_full_kv=True, + split_retained_state_io=True, + retained_state=True, + mos=1, + aic_enable_depth_first=True, + prefill_only=True, + enable_chunking=True, + skip_vision=True, + use_onnx_subfunctions=False, +) + +if skip_vision: # for only LLM with DA + lang_prefill_session = QAICInferenceSession(prefill_qpc_path) + lang_decode_session = QAICInferenceSession(decode_qpc_path) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] +else: + vision_qpc_path = qeff_model.compile( + batch_size=BS, + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + height=354, + width=536, + num_cores=16, + num_devices=num_devices, + retained_state=True, + mos=1, + aic_enable_depth_first=True, + # prefill_only=True, + # enable_chunking=True, + skip_vision=skip_vision, + skip_lang=True, + use_onnx_subfunctions=False, + ) + vision_session = QAICInferenceSession(vision_qpc_path) + lang_prefill_session = QAICInferenceSession(prefill_qpc_path) + lang_decode_session = QAICInferenceSession(decode_qpc_path) + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + image = Image.open(requests.get(image_url, stream=True).raw) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Descibe all the colors seen in the image."}, + ], + }, + ] + + +########################### example for inference + +messages = [messages] * BS + +texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + +image_inputs, video_inputs = process_vision_info(messages) +inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", +) +inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=PREFILL_SEQ_LEN, batch_size=BS) + +pad_token_id = 1 +input_len = inputs["attention_mask"].sum(1, keepdims=True) +input_ids_length = inputs["input_ids"].shape[1] +num_chunks = -(input_ids_length // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +generation_len = CTX_LEN - input_len.max() +print(f"generation_len : {generation_len}") +generated_ids = np.full((BS, generation_len + 1), pad_token_id) + + +inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, +) +inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 +) + +for k, v in inputs.items(): + inputs[k] = np.array(v) + +vision_inputs = { + k: v + for k, v in inputs.items() + if k in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} +} + +vision_inputs_fp16 = {"pixel_values", "image_masks"} +vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) + +vision_start = perf_counter() +vision_outputs = {} +if vision_inputs: + vision_outputs = vision_session.run(vision_inputs) +vision_end = perf_counter() + +# TODO : pass vision_embeds_RetainedState to prefill +# vision_outputs["vision_embeds_RetainedState"] +# *** KeyError: 'vision_embeds_RetainedState' + +lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} +if "position_ids" in inputs: + lang_inputs["position_ids"] = inputs["position_ids"] + lang_inputs.pop("attention_mask") +else: + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + +lang_inputs["image_idx"] = np.array([[0]]) + + +# RUN prefill +lang_start = perf_counter() + +all_outputs = [] +for i in range(num_chunks): + chunk_inputs = lang_inputs.copy() + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = lang_inputs["position_ids"][..., i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + outputs = lang_prefill_session.run(chunk_inputs) + for i in range(config.text_config.num_hidden_layers): + lang_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] + lang_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] + + chunk_inputs["image_idx"] = outputs["image_idx_output"] +prefill_time = perf_counter() - lang_start + vision_end - vision_start +print(f"Prefill time :{prefill_time:.2f} secs") + + +all_outputs.append(np.argmax(outputs["logits"])) +decode_inputs = { + "input_ids": np.argmax(outputs["logits"]).reshape(1, 1), + "position_ids": np.max(lang_inputs["position_ids"]).reshape(1, 1) + 1, +} + +for i in range(config.text_config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] + decode_inputs["vision_embeds_RetainedState"] = outputs["vision_embeds_RetainedState"] + decode_inputs["image_idx_output"] = outputs["image_idx_output"] + +st = perf_counter() +decode_out = lang_decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {perf_counter() - st} sec\n") + +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + + +for i in range(config.text_config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + loop_decode_inputs["vision_embeds_RetainedState"] = decode_out["vision_embeds_RetainedState"] + loop_decode_inputs["image_idx_output"] = decode_out["image_idx_output"] + + +st = perf_counter() +for i in range(generation_len - 2): + decode_out = lang_decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for j in range(config.text_config.num_hidden_layers): + loop_decode_inputs[f"past_key.{j}"] = decode_out[f"past_key.{j}_RetainedState"] + loop_decode_inputs[f"past_value.{j}"] = decode_out[f"past_value.{j}_RetainedState"] + loop_decode_inputs["vision_embeds_RetainedState"] = decode_out["vision_embeds_RetainedState"] + loop_decode_inputs["image_idx_output"] = decode_out["image_idx_output"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = perf_counter() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"\noutput\n{tokenizer.decode(all_outputs)}") From 0d4b73e412d7eb4aba36f881cb1aa30e7fe8cc77 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Fri, 27 Feb 2026 06:54:53 +0000 Subject: [PATCH 06/14] Pytorch Transform Fix Signed-off-by: Dipankar Sarkar --- .../transformers/models/pytorch_transforms.py | 160 ++++++++++++++---- 1 file changed, 131 insertions(+), 29 deletions(-) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 00215b15e..00e7f2d23 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -116,6 +116,11 @@ MistralModel, MistralRMSNorm, ) +from transformers.models.mistral3.modeling_mistral3 import ( + Mistral3ForConditionalGeneration, + Mistral3Model, + Mistral3RMSNorm, +) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -138,6 +143,13 @@ MllamaVisionModel, ) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel +from transformers.models.olmo2.modeling_olmo2 import ( + Olmo2Attention, + Olmo2DecoderLayer, + Olmo2ForCausalLM, + Olmo2Model, + Olmo2RMSNorm, +) from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel from transformers.models.phi3.modeling_phi3 import ( Phi3Attention, @@ -146,6 +158,7 @@ Phi3Model, Phi3RMSNorm, ) +from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralVisionModel from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2DecoderLayer, @@ -158,6 +171,7 @@ Qwen2_5_VLAttention, Qwen2_5_VLDecoderLayer, Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, Qwen2_5_VLTextModel, Qwen2_5_VLVisionAttention, ) @@ -171,6 +185,25 @@ Qwen3Model, Qwen3RMSNorm, ) +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeModel, + Qwen3MoeRMSNorm, + Qwen3MoeRotaryEmbedding, + Qwen3MoeSparseMoeBlock, +) +from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + Qwen3VLModel, + Qwen3VLTextAttention, + Qwen3VLTextDecoderLayer, + Qwen3VLTextModel, + Qwen3VLTextRMSNorm, + Qwen3VLVisionAttention, + Qwen3VLVisionModel, +) from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel, @@ -334,6 +367,11 @@ QEffMistralForCausalLM, QEffMistralModel, ) +from QEfficient.transformers.models.mistral3.modeling_mistral3 import ( + QEffMistral3ForConditionalGeneration, + QEffMistral3Model, + QEffPixtralVisionModel, +) from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( QEffMixtralAttention, QeffMixtralDecoderLayer, @@ -354,12 +392,25 @@ QEffMllamaTextSelfAttention, QEffMllamaVisionModel, ) +from QEfficient.transformers.models.molmo.modeling_molmo import ( + QEffMolmo, + QEffMolmoBlock, + QEffMolmoModel, + QEffMolmoSequentialBlock, + QEffMultiHeadDotProductAttention, +) from QEfficient.transformers.models.mpt.modeling_mpt import ( QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel, ) +from QEfficient.transformers.models.olmo2.modeling_olmo2 import ( + QEffOlmo2Attention, + QEffOlmo2DecoderLayer, + QEffOlmo2ForCausalLM, + QEffOlmo2Model, +) from QEfficient.transformers.models.phi.modeling_phi import ( QEffPhiAttention, QEffPhiDecoderLayer, @@ -382,9 +433,10 @@ QEffQwen2_5_VisionTransformerPretrainedModel, QEffQwen2_5_VLAttention, QEffQwen2_5_VLDecoderLayer, + QEffQwen2_5_VLModel, QEffQwen2_5_VLTextModel, - # QEffQwen2_5_VLModel, QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_DecoderWrapper, QEffQwen_2_5_vl_ForConditionalGeneration, ) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( @@ -402,6 +454,15 @@ QEffQwen3MoeRotaryEmbedding, QEffQwen3MoeSparseMoeBlock, ) +from QEfficient.transformers.models.qwen3_vl.modeling_qwen3_vl import ( + QEffQwen3VLForConditionalGeneration, + QEffQwen3VLModel, + QEffQwen3VLTextAttention, + QEffQwen3VLTextDecoderLayer, + QEffQwen3VLTextModel, + QEffQwen3VLVisionAttention, + QEffQwen3VLVisionModel, +) from QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock, QEffQwen3VLMoeForConditionalGeneration, @@ -447,6 +508,7 @@ class CustomOpsTransform(ModuleMappingTransform): LlamaRMSNorm: CustomRMSNormAIC, Llama4TextRMSNorm: CustomRMSNormAIC, MistralRMSNorm: CustomRMSNormAIC, + Mistral3RMSNorm: CustomRMSNormAIC, MixtralRMSNorm: CustomRMSNormAIC, Phi3RMSNorm: CustomRMSNormAIC, Qwen2RMSNorm: CustomRMSNormAIC, @@ -454,11 +516,13 @@ class CustomOpsTransform(ModuleMappingTransform): Qwen2_5RMSNorm: CustomRMSNormAIC, MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, + PixtralRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, - Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC, - # Qwen3VLTextRMSNorm: CustomRMSNormAIC, + Qwen3MoeRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, - # Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC, + Olmo2RMSNorm: CustomRMSNormAIC, + Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC, + Qwen3VLTextRMSNorm: CustomRMSNormAIC, } @@ -511,12 +575,28 @@ class KVCacheTransform(ModuleMappingTransform): GemmaModel: QEffGemmaModel, GemmaForCausalLM: QEffGemmaForCausalLM, # Qwen3Moe - # Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, - # Qwen3MoeModel: QEffQwen3MoeModel, - # Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, - # Qwen3MoeAttention: QEffQwen3MoeAttention, - # Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, - # Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, + Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, + Qwen3MoeModel: QEffQwen3MoeModel, + Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, + Qwen3MoeAttention: QEffQwen3MoeAttention, + Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, + Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, + Qwen3VLMoeForConditionalGeneration: QEffQwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel: QEffQwen3VLMoeModel, + Qwen3VLMoeTextAttention: QEffQwen3VLMoeTextAttention, + Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer, + Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention, + Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel, + Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel, + Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, + # Qwen3vl + Qwen3VLForConditionalGeneration: QEffQwen3VLForConditionalGeneration, + Qwen3VLModel: QEffQwen3VLModel, + Qwen3VLTextAttention: QEffQwen3VLTextAttention, + Qwen3VLTextDecoderLayer: QEffQwen3VLTextDecoderLayer, + Qwen3VLVisionAttention: QEffQwen3VLVisionAttention, + Qwen3VLVisionModel: QEffQwen3VLVisionModel, + Qwen3VLTextModel: QEffQwen3VLTextModel, # Gemma2 Gemma2Attention: QEffGemma2Attention, Gemma2DecoderLayer: QEffGemma2DecoderLayer, @@ -564,6 +644,9 @@ class KVCacheTransform(ModuleMappingTransform): MistralDecoderLayer: QEffMistralDecoderLayer, MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, + # Mistral3 + Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, + Mistral3Model: QEffMistral3Model, # Mixtral MixtralAttention: QEffMixtralAttention, MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, @@ -585,38 +668,26 @@ class KVCacheTransform(ModuleMappingTransform): PhiDecoderLayer: QEffPhiDecoderLayer, PhiModel: QEffPhiModel, PhiForCausalLM: QEffPhiForCausalLM, + # Pixtral + PixtralVisionModel: QEffPixtralVisionModel, # Qwen2 Qwen2Attention: QEffQwen2Attention, Qwen2DecoderLayer: QEffQwen2DecoderLayer, Qwen2Model: QEffQwen2Model, Qwen2ForCausalLM: QEffQwen2ForCausalLM, - # Qwen2.5 VL - Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, - # Qwen2_5_VLModel: QEffQwen2_5_VLModel, - Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, # Qwen3 Qwen3Attention: QEffQwen3Attention, Qwen3DecoderLayer: QEffQwen3DecoderLayer, Qwen3Model: QEffQwen3Model, Qwen3ForCausalLM: QEffQwen3ForCausalLM, # Qwen2.5 VL - # Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, - # Qwen2_5_VLModel: QEffQwen2_5_VLModel, + Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, + Qwen2_5_VLModel: QEffQwen2_5_VLModel, Qwen2_5_VLAttention: QEffQwen2_5_VLAttention, Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer, Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention, - # Qwen3vlmoe - Qwen3VLMoeForConditionalGeneration: QEffQwen3VLMoeForConditionalGeneration, - Qwen3VLMoeModel: QEffQwen3VLMoeModel, - Qwen3VLMoeTextAttention: QEffQwen3VLMoeTextAttention, - Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer, - Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention, - Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel, - Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel, - Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock, - # Grok1 - # Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, + Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, # Starcoder2 Starcoder2Attention: QEffStarcoder2Attention, Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, @@ -627,6 +698,11 @@ class KVCacheTransform(ModuleMappingTransform): GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM, + # Olmo2 + Olmo2Attention: QEffOlmo2Attention, + Olmo2DecoderLayer: QEffOlmo2DecoderLayer, + Olmo2Model: QEffOlmo2Model, + Olmo2ForCausalLM: QEffOlmo2ForCausalLM, # Whisper encoder and decoder layers WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding, WhisperAttention: QEffWhisperAttention, @@ -704,7 +780,7 @@ class SpDTransform: # Llama QEffLlamaForCausalLM, QEffQwen2ForCausalLM, - # QEffQwen3ForCausalLM, + QEffQwen3ForCausalLM, } @classmethod @@ -770,7 +846,7 @@ class SamplerTransform: QEffMptForCausalLM, QEffPhi3ForCausalLM, QEffQwen2ForCausalLM, - # QEffQwen_2_5_vl_DecoderWrapper, + QEffQwen_2_5_vl_DecoderWrapper, } @classmethod @@ -816,6 +892,32 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, + # Mapping for Molmo + "MolmoForCausalLM": { + "forward": QEffMolmoModel.forward, + "get_qeff_vision_encoder": QEffMolmoModel.get_qeff_vision_encoder, + "get_qeff_language_decoder": QEffMolmoModel.get_qeff_language_decoder, + "get_specializations": QEffMolmoModel.get_specializations, + "get_onnx_dynamic_axes": QEffMolmoModel.get_onnx_dynamic_axes, + "get_output_names": QEffMolmoModel.get_output_names, + "get_dummy_inputs": QEffMolmoModel.get_dummy_inputs, + "get_inputs_info": QEffMolmoModel.get_inputs_info, + }, + "RMSLayerNorm": {"forward": CustomRMSNormAIC.forward}, + # "MolmoForCausalLM": {"forward": QEffMolmoForCausalLM.forward}, + "Molmo": {"forward": QEffMolmo.forward}, + "MolmoSequentialBlock": { + "forward": QEffMolmoSequentialBlock.forward, + "attention": QEffMolmoBlock.attention, + "__qeff_init__": QEffMolmoBlock.__qeff_init__, + }, + "MolmoBlock": { + "attention": QEffMolmoBlock.attention, + "__qeff_init__": QEffMolmoBlock.__qeff_init__, + }, + "MultiHeadDotProductAttention": { + "forward": QEffMultiHeadDotProductAttention.forward, + }, # Mapping for grok1 model "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, "Grok1Model": { From 0f0ffa695528b2b4f47c0bedc3c30d3933768ec0 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Fri, 27 Feb 2026 07:36:54 +0000 Subject: [PATCH 07/14] Fix for compile retained state and example file upgrade Signed-off-by: Dipankar Sarkar --- QEfficient/transformers/models/modeling_auto.py | 2 +- examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index bbf5b7270..a16da1e92 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1602,9 +1602,9 @@ def compile( specializations = specializations["lang"] self.lang_model._compile( - onnx_path=self.lang_model.onnx_path, compile_dir=compile_dir, compile_only=True, + retained_state=True, specializations=specializations, convert_to_fp16=True, mxfp6_matmul=mxfp6_matmul, diff --git a/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 0e308c1e7..0513619ec 100644 --- a/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -51,8 +51,7 @@ mxfp6_matmul=True, mxint8_kv_cache=True, retain_full_kv=True, - split_retained_state_io=True, - retained_state=True, + # split_retained_state_io=True, # This should be used for disagg serving via VLLM mos=1, aic_enable_depth_first=True, prefill_only=True, @@ -73,8 +72,7 @@ mxfp6_matmul=True, mxint8_kv_cache=True, retain_full_kv=True, - split_retained_state_io=True, - retained_state=True, + # split_retained_state_io=True, # This should be used for disagg serving via VLLM mos=1, aic_enable_depth_first=True, prefill_only=True, @@ -103,7 +101,6 @@ width=536, num_cores=16, num_devices=num_devices, - retained_state=True, mos=1, aic_enable_depth_first=True, # prefill_only=True, @@ -130,8 +127,6 @@ ] -########################### example for inference - messages = [messages] * BS texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] From 70b2cd2ee9034defcc567cf13c26f40cfd6a34f4 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Fri, 27 Feb 2026 07:46:51 +0000 Subject: [PATCH 08/14] Rearranging the example files and removing redundant file Signed-off-by: Dipankar Sarkar --- .../models/qwen3vl}/qwen3_vl.py | 6 +- .../models/qwen3vl/qwen3vl_multi.py | 0 examples/qwen3_vl_moe.py | 134 ------------------ 3 files changed, 4 insertions(+), 136 deletions(-) rename examples/{ => image_text_to_text/models/qwen3vl}/qwen3_vl.py (96%) delete mode 100644 examples/image_text_to_text/models/qwen3vl/qwen3vl_multi.py delete mode 100644 examples/qwen3_vl_moe.py diff --git a/examples/qwen3_vl.py b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py similarity index 96% rename from examples/qwen3_vl.py rename to examples/image_text_to_text/models/qwen3vl/qwen3_vl.py index 75bed1630..aa3e34dab 100644 --- a/examples/qwen3_vl.py +++ b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py @@ -15,7 +15,9 @@ model_id = "Qwen/Qwen3-VL-32B-Instruct" config = AutoConfig.from_pretrained(model_id) - +config.vision_config.depth = 9 +config.text_config.num_hidden_layers = 1 +config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config ) @@ -76,7 +78,7 @@ prefill_seq_len=128, ctx_len=4096, num_cores=16, - num_devices=4, + num_devices=2, height=354, width=536, # height=1024, diff --git a/examples/image_text_to_text/models/qwen3vl/qwen3vl_multi.py b/examples/image_text_to_text/models/qwen3vl/qwen3vl_multi.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/qwen3_vl_moe.py b/examples/qwen3_vl_moe.py deleted file mode 100644 index 931cfe093..000000000 --- a/examples/qwen3_vl_moe.py +++ /dev/null @@ -1,134 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import requests -import transformers -from PIL import Image -from qwen_vl_utils import process_vision_info -from transformers import AutoConfig, AutoProcessor, TextStreamer - -from QEfficient import QEFFAutoModelForImageTextToText - -model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" -config = AutoConfig.from_pretrained(model_id) - -# For Testing Purpose Only -config.vision_config.depth = 1 -config.text_config.num_hidden_layers = 1 - -qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config -) - -tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) -processor = AutoProcessor.from_pretrained(model_id) -### use skip_vision=Ture, if want to run only text, ow false ### -skip_vision = False - -if skip_vision: - ## Only Text ## - ## Set Batch_Size ## - batch_size = 1 - qeff_model.compile( - batch_size=batch_size, - prefill_seq_len=128, - ctx_len=4096, - num_cores=16, - num_devices=4, - height=354, - width=536, - mxfp6_matmul=True, - aic_enable_depth_first=True, - skip_vision=True, - mos=1, - ) - - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Tell me about yourself."}, - ], - }, - ] - - messages = [messages] * batch_size - - inputs = processor.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt", - ) - inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - streamer = TextStreamer(tokenizer) - output = qeff_model.generate(inputs=inputs, generation_len=100) - print(output.generated_ids) - print(tokenizer.batch_decode(output.generated_ids)) - print(output) - -else: - batch_size = 1 - ## Vision + Text ## - qeff_model.compile( - batch_size=batch_size, - prefill_seq_len=128, - ctx_len=4096, - num_cores=16, - num_devices=4, - height=354, - width=536, - mxfp6_matmul=True, - mxint8_kv_cache=True, - aic_enable_depth_first=True, - mos=1, - ) - - ### IMAGE + TEXT ### - image_url = "https://picsum.photos/id/237/536/354" - - image = Image.open(requests.get(image_url, stream=True).raw) - - messages_1 = [ - { - "role": "user", - "content": [ - {"type": "image", "image": image}, - {"type": "text", "text": "Descibe all the colors seen in the image."}, - ], - }, - ] - - # messages_2 = [ - # { - # "role": "user", - # "content": [ - # {"type": "image", "image": image}, - # {"type": "text", "text": "Describe about the color of the dog."}, - # ], - # }, - # ] - - messages = [messages_1] * batch_size - - texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] - - image_inputs, video_inputs = process_vision_info(messages) - inputs = processor( - text=texts, - images=image_inputs, - videos=video_inputs, - padding=True, - return_tensors="pt", - ) - inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - streamer = TextStreamer(tokenizer) - output = qeff_model.generate(inputs=inputs, generation_len=100) - print(output.generated_ids) - print(tokenizer.batch_decode(output.generated_ids)) - print(output) From 18c4f2c3ac9790fd1258b29f47bfbe3844c16702 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Fri, 27 Feb 2026 11:27:00 +0000 Subject: [PATCH 09/14] Fix cb for qwen3vl and qwen3vlmoe with deepstack_features enabled Signed-off-by: Dipankar Sarkar --- QEfficient/generation/embedding_handler.py | 10 +- QEfficient/generation/vlm_generation.py | 34 +++- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 173 +++++++++++------- .../qwen3_vl_moe/qwen3_vl_disagg_mode.py | 11 +- .../models}/qwen3_vl_moe/qwen3_vl_moe.py | 3 +- .../qwen3_vl_moe_contnious_batching.py | 3 +- .../models/qwen3vl/qwen3_vl.py | 4 +- .../qwen3vl/qwen3_vl_continous_batching.py | 70 +++++++ 8 files changed, 231 insertions(+), 77 deletions(-) rename examples/{ => image_text_to_text/models}/qwen3_vl_moe/qwen3_vl_disagg_mode.py (97%) rename examples/{ => image_text_to_text/models}/qwen3_vl_moe/qwen3_vl_moe.py (97%) rename examples/{ => image_text_to_text/models}/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py (95%) create mode 100644 examples/image_text_to_text/models/qwen3vl/qwen3_vl_continous_batching.py diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index 13fcc6f93..bbccadb70 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -268,6 +268,14 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] ) + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "qwen3_vl" + ): + inputs = self._qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) + # Convert to float32 if needed if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) @@ -418,7 +426,7 @@ def setup_vision_buffers(self): buffers = {} for output_name, shape in shapes.items(): # Create placeholder with appropriate dtype - if "vision_embeds" in output_name: + if "vision_embeds" or "deepstack_features" in output_name: buffers[output_name] = np.zeros(shape, dtype=np.float16) else: buffers[output_name] = np.zeros(shape, dtype=np.float32) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 0df3e6511..05e867644 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -152,6 +152,9 @@ def __init__( self.is_qwen3_vl_moe = ( hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe" ) + self.is_qwen3_vl = ( + hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl" + ) self.qeff_model = qeff_model self.processor = processor self.tokenizer = tokenizer @@ -263,6 +266,8 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) elif self.is_qwen3_vl_moe: _ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id) + elif self.is_qwen3_vl: + _ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id) else: _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) @@ -308,6 +313,27 @@ def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_le self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id + def update_decode_inputs_qwen3_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): + """ + Updates the decode input with the generated values. + Args: + outputs (dict): The outputs of the model. + position_ids (array): The position IDs. + generation_len (int): The generation length. + decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. + + Returns: + next_token_id (array): The next token ID. + """ + next_token_id = self._fetch_next_token_id(outputs) + + # Store the generated values. + self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id + self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) + self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) + self.generation_len[decode_batch_id or slice(None)] = generation_len + return next_token_id + def _execute_chunked_prefill( self, lang_inputs: Dict[str, np.ndarray], @@ -610,6 +636,8 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) if self.is_qwen3_vl_moe: self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) + if self.is_qwen3_vl: + self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) # Create prompt queue prompt_queue = deque(vision_prompts) @@ -720,10 +748,14 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation self.update_decode_inputs_qwen2_5_vl( outputs, position_ids_decode, generation_len_final, decode_batch_id ) - if self.is_qwen3_vl_moe: + elif self.is_qwen3_vl_moe: self.update_decode_inputs_qwen3_vl_moe( outputs, position_ids_decode, generation_len_final, decode_batch_id ) + elif self.is_qwen3_vl: + self.update_decode_inputs_qwen3_vl( + outputs, position_ids_decode, generation_len_final, decode_batch_id + ) else: self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id) else: diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index dbae48c4a..f9f542daa 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -555,6 +555,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + layer_idx = 0 for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -578,13 +579,13 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - layer_idx = 0 - if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + if deepstack_visual_embeds is not None and layer_idx in range(deepstack_visual_embeds.shape[0]): hidden_states = self._deepstack_process( hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx], ) + layer_idx += 1 hidden_states = self.norm(hidden_states) if output_hidden_states: @@ -600,6 +601,57 @@ def forward( attentions=all_self_attns, ) + return (hidden_states, past_key_values) + + def _deepstack_process( + self, + hidden_states: torch.Tensor, + visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor, + ): + visual_pos_masks = visual_pos_masks.unsqueeze(-1).expand(-1, -1, self.config.hidden_size) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states = hidden_states.clone() + mixed_embeds = hidden_states + visual_embeds + + local_this = torch.where(visual_pos_masks, mixed_embeds, hidden_states) + + return local_this + + +class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + act = getattr(self.experts, "act_fn", F.silu) + + router_logits = self.gate(x) # [T, E] + prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k] + top_w = top_w / top_w.sum(dim=-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + + # gate_up_proj: [E, H, 2I], down_proj: [E, I, H] + W_up = self.experts.gate_up_proj + W_dn = self.experts.down_proj + E, H_w, twoI = W_up.shape + I2 = twoI // 2 + routing_weights = torch.zeros_like(prob, dtype=hidden_states.dtype) # [T, E] + routing_weights.scatter_(1, top_i, top_w) + expert_out = x.new_zeros((T, H)) + for e in range(E): + rw = routing_weights[:, e].unsqueeze(-1) # [T, 1] + # Split fused [H, 2I] -> [H, I] + [H, I] + W_gate_e = W_up[e, :, :I2] + W_up_e = W_up[e, :, I2:] + W_dn_e = W_dn[e, :, :] + gate = x @ W_gate_e + up = x @ W_up_e + down = (up * act(gate)) @ W_dn_e + expert_out.add_(down * rw) + return expert_out.view(B, S, H), router_logits + class QEffQwen3VLMoeModel(Qwen3VLMoeModel): def forward( @@ -659,42 +711,29 @@ def __init__(self, model): self.model = model self.model.vision_model = self.model.visual - def get_submodules_for_export(self) -> Type[nn.Module]: - """ - Return the set of class used as the repeated layer across the model for subfunction extraction. - Notes: - This method should return the *class object* (not an instance). - Downstream code can use this to find/build subfunctions for repeated blocks. - """ - return {self.model.visual.blocks[0].__class__} - def forward(self, pixel_values, image_grid_thw): - image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0] + image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw) bs = image_grid_thw.shape[0] split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1)) - return image_embeds + deepstack_features = torch.stack( + [feature.reshape(bs, split_size, feature.size(1)) for feature in deepstack_feature_lists], + dim=0, # new axis for "features" + ) + return image_embeds, deepstack_features class QEffQwen3VLDecoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model - self.language_model = self.model.model.language_model - - def get_submodules_for_export(self) -> Type[nn.Module]: - """ - Return the set of class used as the repeated layer across the model for subfunction extraction. - Notes: - This method should return the *class object* (not an instance). - Downstream code can use this to find/build subfunctions for repeated blocks. - """ - return {QEffQwen3VLDecoderWrapper} + self.language_model = self.model.model def forward( self, input_ids, vision_embeds, + deepstack_features, position_ids, image_idx, past_key_values, @@ -708,21 +747,38 @@ def forward( indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + + num_features, bs, split_size, C = deepstack_features.shape + x = deepstack_features.reshape(num_features, bs * split_size, C) + deepstack_features_expanded = x[:, indices1, :] image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) - inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) - outputs = self.model.model( + # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + inputs_embeds = image_input_embeds + + image_mask = selected.clone() + + visual_pos_masks = None + deepstack_visual_embeds = None + + if image_mask is not None: + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_features_expanded + + outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=True, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] logits = self.model.lm_head(hidden_states) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - return logits, vision_embeds, image_idx, outputs.past_key_values + return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values class QEffQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): @@ -732,7 +788,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens x = hidden_states.view(T, H) router_logits = self.gate(x) - prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) + prob = F.softmax(router_logits, dim=-1, dtype=torch.float) top_w, top_i = torch.topk(prob, self.top_k, dim=-1) top_w = top_w / top_w.sum(dim=1, keepdim=True) top_w = top_w.to(x.dtype) @@ -754,40 +810,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return experts_out, router_logits -class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - B, S, H = hidden_states.shape - T = B * S - x = hidden_states.view(T, H) - act = getattr(self.experts, "act_fn", F.silu) - - router_logits = self.gate(x) # [T, E] - prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) - top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k] - top_w = top_w / top_w.sum(dim=-1, keepdim=True) - top_w = top_w.to(hidden_states.dtype) - - # gate_up_proj: [E, H, 2I], down_proj: [E, I, H] - W_up = self.experts.gate_up_proj - W_dn = self.experts.down_proj - E, H_w, twoI = W_up.shape - I2 = twoI // 2 - routing_weights = torch.zeros_like(prob, dtype=hidden_states.dtype) # [T, E] - routing_weights.scatter_(1, top_i, top_w) - expert_out = x.new_zeros((T, H)) - for e in range(E): - rw = routing_weights[:, e].unsqueeze(-1) # [T, 1] - # Split fused [H, 2I] -> [H, I] + [H, I] - W_gate_e = W_up[e, :, :I2] - W_up_e = W_up[e, :, I2:] - W_dn_e = W_dn[e, :, :] - gate = x @ W_gate_e - up = x @ W_up_e - down = (up * act(gate)) @ W_dn_e - expert_out.add_(down * rw) - return expert_out.view(B, S, H), router_logits - - class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): def get_qeff_vision_encoder(self): return QEffQwen3VLEncoderWrapper(self) @@ -804,6 +826,7 @@ def get_dummy_inputs( ): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + # vision_size = 1024 vision_size = 187 inputs_shapes["vision_embeds"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, @@ -819,6 +842,12 @@ def get_dummy_inputs( inputs_shapes["pixel_values"] = (748, 1536) inputs_shapes["image_idx"] = (1, 1) inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2) + inputs_shapes["deepstack_features"] = ( + len(self.config.vision_config.deepstack_visual_indexes), + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + vision_size, + self.model.config.vision_config.out_hidden_size, + ) vision_inputs = {} lang_inputs = {} @@ -836,6 +865,7 @@ def get_dummy_inputs( .repeat(4, 1, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + lang_inputs["deepstack_features"] = torch.zeros((inputs_shapes["deepstack_features"]), dtype=torch.float32) # Add data for KV bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE @@ -874,6 +904,8 @@ def get_specializations( img_size: None, height: int = None, width: int = None, + time: int = 1, + # dimensions: List = None, num_frames: int = 1, kv_offload: bool = False, continuous_batching: bool = False, @@ -949,8 +981,8 @@ def smart_resize( grid_height = grid_h * grid_w grid_width = patch_size * patch_size * temporal_patch_size * channel vision_size = grid_height // 4 - vision_size = vision_size * num_frames - grid_height = grid_height * batch_size + vision_size = vision_size * num_frames * time + grid_height = grid_height * time * batch_size vision = [ { @@ -958,8 +990,10 @@ def smart_resize( "vision_size": vision_size, "grid_height": grid_height, "grid_width": grid_width, + "time": time, "grid_h": grid_h, "grid_w": grid_w, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), } ] @@ -974,6 +1008,7 @@ def smart_resize( "vision_size": vision_size, "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), } if continuous_batching: @@ -993,6 +1028,7 @@ def smart_resize( "vision_size": vision_size, "comp_ctx_lengths": comp_ctx_lengths_decode[i], "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), } if continuous_batching: @@ -1008,6 +1044,7 @@ def smart_resize( "ctx_len": ctx_len, "vision_size": vision_size, "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), } if continuous_batching: @@ -1023,6 +1060,7 @@ def smart_resize( "ctx_len": ctx_len, "vision_size": vision_size, "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), } if continuous_batching: @@ -1050,13 +1088,15 @@ def get_onnx_dynamic_axes( num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { "pixel_values": {0: "grid_height", 1: "grid_width"}, - "image_grid_thw": {0: "batch_size", 2: "grid_h", 3: "grid_w"}, + "image_grid_thw": {0: "batch_size", 1: "time", 2: "grid_h", 3: "grid_w"}, + "deepstack_features": {0: "num_feature_layers", 1: "batch_size", 2: "vision_size"}, } lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {1: "batch_size", 2: "seq_len"}, "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, + "deepstack_features": {0: "num_feature_layers", 1: "vision_batch_size", 2: "vision_size"}, } for i in range(num_layers): @@ -1087,6 +1127,7 @@ def get_onnx_dynamic_axes( def get_output_names(self, kv_offload: bool = False): vision_output_names = ["vision_embeds"] + vision_output_names.append("deepstack_features") lang_output_names = ["logits"] for i in range(self.model.config.text_config.num_hidden_layers): for kv in ["key", "value"]: @@ -1096,6 +1137,7 @@ def get_output_names(self, kv_offload: bool = False): if kv_offload: lang_output_names.insert(1, "vision_embeds_RetainedState") lang_output_names.insert(2, "image_idx_output") + lang_output_names.insert(2, "deepstack_features_RetainedState") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: @@ -1122,7 +1164,6 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size= inputs["position_ids"] = F.pad( inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 ) - inputs.pop("image_grid_thw", None) return inputs def get_inputs_info(self): diff --git a/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py similarity index 97% rename from examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py rename to examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 0513619ec..088026ce6 100644 --- a/examples/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -23,9 +23,9 @@ # TODO clean up this script # For Testing Purpose Only -# config.vision_config.depth = 1 -# config.text_config.num_hidden_layers = 1 -num_devices = 4 +config.vision_config.depth = 1 +config.text_config.num_hidden_layers = 1 +num_devices = 1 qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config @@ -212,11 +212,10 @@ prefill_time = perf_counter() - lang_start + vision_end - vision_start print(f"Prefill time :{prefill_time:.2f} secs") - all_outputs.append(np.argmax(outputs["logits"])) decode_inputs = { "input_ids": np.argmax(outputs["logits"]).reshape(1, 1), - "position_ids": np.max(lang_inputs["position_ids"]).reshape(1, 1) + 1, + "position_ids": np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1, } for i in range(config.text_config.num_hidden_layers): @@ -230,7 +229,7 @@ print(f"time for first run of decode with KV as input = {perf_counter() - st} sec\n") all_outputs.append(np.argmax(decode_out["logits"])) -pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +pos_id = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 loop_decode_inputs = { "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), "position_ids": pos_id, diff --git a/examples/qwen3_vl_moe/qwen3_vl_moe.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py similarity index 97% rename from examples/qwen3_vl_moe/qwen3_vl_moe.py rename to examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py index 931cfe093..1f7cd5a06 100644 --- a/examples/qwen3_vl_moe/qwen3_vl_moe.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py @@ -17,8 +17,9 @@ config = AutoConfig.from_pretrained(model_id) # For Testing Purpose Only -config.vision_config.depth = 1 +config.vision_config.depth = 9 config.text_config.num_hidden_layers = 1 +config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config diff --git a/examples/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py similarity index 95% rename from examples/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py rename to examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py index f209ad87b..b134dd6fd 100644 --- a/examples/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py @@ -12,8 +12,9 @@ model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" config = AutoConfig.from_pretrained(model_id) -config.vision_config.depth = 1 +config.vision_config.depth = 9 config.text_config.num_hidden_layers = 1 +config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, diff --git a/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py index aa3e34dab..a71b5a65c 100644 --- a/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py +++ b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py @@ -15,9 +15,11 @@ model_id = "Qwen/Qwen3-VL-32B-Instruct" config = AutoConfig.from_pretrained(model_id) + config.vision_config.depth = 9 config.text_config.num_hidden_layers = 1 config.vision_config.deepstack_visual_indexes = [8] + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config ) @@ -78,7 +80,7 @@ prefill_seq_len=128, ctx_len=4096, num_cores=16, - num_devices=2, + num_devices=4, height=354, width=536, # height=1024, diff --git a/examples/image_text_to_text/models/qwen3vl/qwen3_vl_continous_batching.py b/examples/image_text_to_text/models/qwen3vl/qwen3_vl_continous_batching.py new file mode 100644 index 000000000..dcf9ae001 --- /dev/null +++ b/examples/image_text_to_text/models/qwen3vl/qwen3_vl_continous_batching.py @@ -0,0 +1,70 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "Qwen/Qwen3-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +config.vision_config.depth = 9 +config.text_config.num_hidden_layers = 1 +config.vision_config.deepstack_visual_indexes = [8] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) From f1a35f0f6a3fd85813340e9c20697a57a3326049 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Fri, 27 Feb 2026 11:30:52 +0000 Subject: [PATCH 10/14] Cleaning changes 1 Signed-off-by: Dipankar Sarkar --- .../transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 630965790..d6bfbda81 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -591,7 +591,7 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - # kv_seq_len = key_states.shape[-2] + kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 From 932899149b30142ed41f9fbd304fa97b35cdfc8a Mon Sep 17 00:00:00 2001 From: Karthikeya Date: Thu, 5 Mar 2026 15:54:56 +0530 Subject: [PATCH 11/14] updating Subfn, Prefill only logic in Disagg mode (#820) Added Support for Subfn for Qwen 3 VL dense, MOE. Updated prefill only logic for disagg mode --------- Signed-off-by: vtirumal Signed-off-by: Dipankar Sarkar --- .../transformers/models/modeling_auto.py | 68 ++++----- .../models/qwen3_vl/modeling_qwen3_vl.py | 20 ++- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 31 +++- .../qwen3_vl_moe/qwen3_vl_disagg_mode.py | 132 ++++++++---------- .../models/qwen3_vl_moe/qwen3_vl_moe.py | 12 +- ....py => qwen3_vl_moe_continous_batching.py} | 8 +- .../models/qwen3vl/qwen3_vl.py | 2 + 7 files changed, 142 insertions(+), 131 deletions(-) rename examples/image_text_to_text/models/qwen3_vl_moe/{qwen3_vl_moe_contnious_batching.py => qwen3_vl_moe_continous_batching.py} (89%) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a16da1e92..11dea4ec5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1090,16 +1090,16 @@ def export( Path to the generated ONNX graph file for the language decoder. """ if prefill_only: - if prefill_seq_len > 1: - if not enable_chunking and self.continuous_batching: - raise NotImplementedError( - "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" - ) - self.hash_params["prefill_only"] = True - self.prefill(enable=True, enable_chunking=enable_chunking) - else: - self.hash_params["prefill_only"] = False - self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + assert prefill_seq_len > 1 + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) + self.hash_params["prefill_only"] = True + self.prefill(enable=True, enable_chunking=enable_chunking) + else: + self.hash_params["prefill_only"] = False + self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) return self._export( inputs, @@ -1286,24 +1286,6 @@ def onnx_path(self): """ return [self.vision_model.onnx_path, self.lang_model.onnx_path] - @property - def qpc_path(self): - """ - Get the QPC paths for the vision and language model components. - - Returns - ------- - Union[List[str], str, None] - A list containing both QPC paths if both are compiled, or just one if only one is, - or None if neither is compiled. - """ - if self.vision_model.qpc_path and self.lang_model.qpc_path: - return [self.vision_model.qpc_path, self.lang_model.qpc_path] - elif self.vision_model.qpc_path: - return self.vision_model.qpc_path - else: - return self.lang_model.qpc_path - def export( self, export_dir: Optional[str] = None, @@ -1416,7 +1398,7 @@ def compile( skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, use_onnx_subfunctions: bool = False, - prefill_only=False, + prefill_only=None, enable_chunking=False, **compiler_options, ) -> str: @@ -1536,11 +1518,7 @@ def compile( if lang_onnx_path: self.lang_model.onnx_path = lang_onnx_path - if ( - (self.vision_model.onnx_path is None and vision_onnx_path is None) - or (self.lang_model.onnx_path is None and lang_onnx_path is None) - or prefill_only - ): + if vision_onnx_path is None or lang_onnx_path is None: self.export( use_onnx_subfunctions=use_onnx_subfunctions, skip_vision=skip_vision, @@ -1554,8 +1532,9 @@ def compile( compiler_options.pop("continuous_batching", None) compiler_options.pop("kv_cache_batch_size", None) compiler_options.pop("full_batch_size", None) + self.qpc_paths = {} if not skip_vision: - self.vision_model._compile( + vision_qpc_path = self.vision_model._compile( compile_dir=compile_dir, compile_only=True, specializations=specializations["vision"], @@ -1568,6 +1547,7 @@ def compile( use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) + self.qpc_paths["vision_qpc_path"] = vision_qpc_path # Custom NPI file options if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options: @@ -1592,16 +1572,17 @@ def compile( if ("vision_embeds" in output_name or "deepstack_features" in output_name) else kv_cache_dtype ) - if prefill_only: - if prefill_seq_len > 1: - specializations = specializations["lang"][:1] # prefill - else: - specializations = specializations["lang"][-1:] # decoder + specializations = specializations["lang"][:1] + qpc_key = "lang_prefill_qpc_path" + elif prefill_seq_len == 1: + specializations = specializations["lang"][-1:] + qpc_key = "lang_decode_qpc_path" else: specializations = specializations["lang"] + qpc_key = "lang_qpc_path" - self.lang_model._compile( + lang_qpc_path = self.lang_model._compile( compile_dir=compile_dir, compile_only=True, retained_state=True, @@ -1615,9 +1596,8 @@ def compile( use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) - if skip_vision and prefill_only: # for disagg serving - return self.lang_model.qpc_path - return self.qpc_path + self.qpc_paths.update({qpc_key: lang_qpc_path}) + return self.qpc_paths def generate( self, diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index cc12e6f39..070856c6e 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -562,6 +562,15 @@ def __init__(self, model): self.model = model self.model.vision_model = self.model.visual + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + return {self.model.visual.blocks[0].__class__} + def forward(self, pixel_values, image_grid_thw): image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw) bs = image_grid_thw.shape[0] @@ -580,6 +589,15 @@ def __init__(self, model): self.model = model self.language_model = self.model.model + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + return {QEffQwen3VLTextDecoderLayer} + def forward( self, input_ids, diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index f9f542daa..076e4cdbd 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -629,7 +629,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens router_logits = self.gate(x) # [T, E] prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k] - top_w = top_w / top_w.sum(dim=-1, keepdim=True) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] top_w = top_w.to(hidden_states.dtype) # gate_up_proj: [E, H, 2I], down_proj: [E, I, H] @@ -711,6 +711,15 @@ def __init__(self, model): self.model = model self.model.vision_model = self.model.visual + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + return {self.model.visual.blocks[0].__class__} + def forward(self, pixel_values, image_grid_thw): image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw) bs = image_grid_thw.shape[0] @@ -727,7 +736,16 @@ class QEffQwen3VLDecoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model - self.language_model = self.model.model + self.language_model = self.model.model.language_model + + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + return {QEffQwen3VLMoeTextDecoderLayer} def forward( self, @@ -790,7 +808,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens router_logits = self.gate(x) prob = F.softmax(router_logits, dim=-1, dtype=torch.float) top_w, top_i = torch.topk(prob, self.top_k, dim=-1) - top_w = top_w / top_w.sum(dim=1, keepdim=True) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] top_w = top_w.to(x.dtype) idx = top_i.reshape(-1) w_up = self.experts.gate_up_proj.index_select(0, idx) @@ -805,9 +823,8 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens intermediate = up * self.experts.act_fn(gate) experts_out = torch.bmm(intermediate, w_dn) experts_out = experts_out.view(T, self.top_k, H) * top_w.unsqueeze(-1) - experts_out = experts_out.sum(dim=1).view(B, S, H) - - return experts_out, router_logits + experts_out = torch.einsum("bnd->bd", experts_out) + return experts_out.view(B, S, H), router_logits class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 088026ce6..897eea350 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -21,11 +21,10 @@ model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" config = AutoConfig.from_pretrained(model_id) -# TODO clean up this script -# For Testing Purpose Only -config.vision_config.depth = 1 -config.text_config.num_hidden_layers = 1 -num_devices = 1 +# For faster execution user can run with lesser layers, For Testing Purpose Only +# config.vision_config.depth = 9 +# config.text_config.num_hidden_layers = 1 +# config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config @@ -36,28 +35,40 @@ PREFILL_SEQ_LEN = 128 CTX_LEN = 4096 BS = 1 -torch.manual_seed(0) -skip_vision = True +skip_vision = False +if not skip_vision: + vision_qpc_path = qeff_model.compile( + batch_size=BS, + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + height=354, + width=536, + num_cores=16, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + skip_vision=skip_vision, + skip_lang=True, + use_onnx_subfunctions=True, + ) prefill_qpc_path = qeff_model.compile( batch_size=BS, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, - height=354, - width=536, num_cores=16, - num_devices=num_devices, + num_devices=1, mxfp6_matmul=True, mxint8_kv_cache=True, retain_full_kv=True, - # split_retained_state_io=True, # This should be used for disagg serving via VLLM + split_retained_state_io=True, # This should be used for disagg serving via VLLM mos=1, aic_enable_depth_first=True, prefill_only=True, enable_chunking=True, skip_vision=True, - use_onnx_subfunctions=False, + use_onnx_subfunctions=True, ) @@ -65,25 +76,23 @@ batch_size=BS, prefill_seq_len=1, ctx_len=CTX_LEN, - height=354, - width=536, num_cores=16, - num_devices=num_devices, + num_devices=1, mxfp6_matmul=True, mxint8_kv_cache=True, retain_full_kv=True, - # split_retained_state_io=True, # This should be used for disagg serving via VLLM + split_retained_state_io=True, # This should be used for disagg serving via VLLM mos=1, aic_enable_depth_first=True, - prefill_only=True, - enable_chunking=True, + prefill_only=False, skip_vision=True, - use_onnx_subfunctions=False, + use_onnx_subfunctions=True, ) -if skip_vision: # for only LLM with DA - lang_prefill_session = QAICInferenceSession(prefill_qpc_path) - lang_decode_session = QAICInferenceSession(decode_qpc_path) +lang_prefill_session = QAICInferenceSession(prefill_qpc_path.get("lang_prefill_qpc_path")) +lang_decode_session = QAICInferenceSession(decode_qpc_path.get("lang_decode_qpc_path")) + +if skip_vision: messages = [ { "role": "user", @@ -93,25 +102,6 @@ }, ] else: - vision_qpc_path = qeff_model.compile( - batch_size=BS, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - height=354, - width=536, - num_cores=16, - num_devices=num_devices, - mos=1, - aic_enable_depth_first=True, - # prefill_only=True, - # enable_chunking=True, - skip_vision=skip_vision, - skip_lang=True, - use_onnx_subfunctions=False, - ) - vision_session = QAICInferenceSession(vision_qpc_path) - lang_prefill_session = QAICInferenceSession(prefill_qpc_path) - lang_decode_session = QAICInferenceSession(decode_qpc_path) ### IMAGE + TEXT ### image_url = "https://picsum.photos/id/237/536/354" image = Image.open(requests.get(image_url, stream=True).raw) @@ -173,15 +163,27 @@ vision_inputs_fp16 = {"pixel_values", "image_masks"} vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) +if not skip_vision: + vision_session = QAICInferenceSession(vision_qpc_path.get("vision_qpc_path")) + vision_start = perf_counter() vision_outputs = {} if vision_inputs: vision_outputs = vision_session.run(vision_inputs) vision_end = perf_counter() -# TODO : pass vision_embeds_RetainedState to prefill -# vision_outputs["vision_embeds_RetainedState"] -# *** KeyError: 'vision_embeds_RetainedState' +if not skip_vision: + vision_session.deactivate() + +lang_prefill_session.activate() +# Skip inputs/outputs +lang_prefill_session.skip_buffers( + [ + x + for x in lang_prefill_session.input_names + lang_prefill_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] +) lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} if "position_ids" in inputs: @@ -197,20 +199,28 @@ # RUN prefill lang_start = perf_counter() - all_outputs = [] +chunk_inputs = lang_inputs.copy() for i in range(num_chunks): - chunk_inputs = lang_inputs.copy() chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] chunk_inputs["position_ids"] = lang_inputs["position_ids"][..., i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] outputs = lang_prefill_session.run(chunk_inputs) - for i in range(config.text_config.num_hidden_layers): - lang_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] - lang_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] - chunk_inputs["image_idx"] = outputs["image_idx_output"] prefill_time = perf_counter() - lang_start + vision_end - vision_start -print(f"Prefill time :{prefill_time:.2f} secs") +print(f"Prefill time : {prefill_time:.2f} secs") + +lang_prefill_session.deactivate() +lang_decode_session.activate() +# Skip inputs/outputs +lang_decode_session.skip_buffers( + [ + x + for x in lang_decode_session.input_names + lang_decode_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] +) + +lang_decode_session.set_buffers(outputs) all_outputs.append(np.argmax(outputs["logits"])) decode_inputs = { @@ -218,12 +228,6 @@ "position_ids": np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1, } -for i in range(config.text_config.num_hidden_layers): - decode_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] - decode_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] - decode_inputs["vision_embeds_RetainedState"] = outputs["vision_embeds_RetainedState"] - decode_inputs["image_idx_output"] = outputs["image_idx_output"] - st = perf_counter() decode_out = lang_decode_session.run(decode_inputs) print(f"time for first run of decode with KV as input = {perf_counter() - st} sec\n") @@ -235,25 +239,11 @@ "position_ids": pos_id, } - -for i in range(config.text_config.num_hidden_layers): - loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] - loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] - loop_decode_inputs["vision_embeds_RetainedState"] = decode_out["vision_embeds_RetainedState"] - loop_decode_inputs["image_idx_output"] = decode_out["image_idx_output"] - - st = perf_counter() for i in range(generation_len - 2): decode_out = lang_decode_session.run(loop_decode_inputs) all_outputs.append(np.argmax(decode_out["logits"])) pos_id += 1 - for j in range(config.text_config.num_hidden_layers): - loop_decode_inputs[f"past_key.{j}"] = decode_out[f"past_key.{j}_RetainedState"] - loop_decode_inputs[f"past_value.{j}"] = decode_out[f"past_value.{j}_RetainedState"] - loop_decode_inputs["vision_embeds_RetainedState"] = decode_out["vision_embeds_RetainedState"] - loop_decode_inputs["image_idx_output"] = decode_out["image_idx_output"] - loop_decode_inputs.update( { "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py index 1f7cd5a06..94ba1564e 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py @@ -16,10 +16,10 @@ model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" config = AutoConfig.from_pretrained(model_id) -# For Testing Purpose Only -config.vision_config.depth = 9 -config.text_config.num_hidden_layers = 1 -config.vision_config.deepstack_visual_indexes = [8] +# For faster execution user can run with lesser layers, For Testing Purpose Only +# config.vision_config.depth = 9 +# config.text_config.num_hidden_layers = 1 +# config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config @@ -27,7 +27,7 @@ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) -### use skip_vision=Ture, if want to run only text, ow false ### +### use skip_vision=Ture, if want to run only text, or false ### skip_vision = False if skip_vision: @@ -46,6 +46,7 @@ aic_enable_depth_first=True, skip_vision=True, mos=1, + use_onnx_subfunctions=True, ) messages = [ @@ -88,6 +89,7 @@ mxint8_kv_cache=True, aic_enable_depth_first=True, mos=1, + use_onnx_subfunctions=True, ) ### IMAGE + TEXT ### diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_continous_batching.py similarity index 89% rename from examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py rename to examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_continous_batching.py index b134dd6fd..9391aeaee 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_contnious_batching.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_continous_batching.py @@ -12,9 +12,11 @@ model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" config = AutoConfig.from_pretrained(model_id) -config.vision_config.depth = 9 -config.text_config.num_hidden_layers = 1 -config.vision_config.deepstack_visual_indexes = [8] + +# For faster execution user can run with lesser layers, For Testing Purpose Only +# config.vision_config.depth = 9 +# config.text_config.num_hidden_layers = 1 +# config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, diff --git a/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py index a71b5a65c..20badbfba 100644 --- a/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py +++ b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py @@ -45,6 +45,7 @@ aic_enable_depth_first=True, skip_vision=True, mos=1, + use_onnx_subfunctions=False, ) messages = [ @@ -89,6 +90,7 @@ mxint8_kv_cache=True, aic_enable_depth_first=True, mos=1, + use_onnx_subfunctions=False, ) ### IMAGE + TEXT ### From 450c8d635c1eef4924272256d0183e8fbdea70df Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Tue, 10 Mar 2026 12:03:14 +0000 Subject: [PATCH 12/14] Addressing Review Comments 2 Signed-off-by: Dipankar Sarkar --- QEfficient/generation/embedding_handler.py | 27 ++------ QEfficient/generation/vlm_generation.py | 79 ++-------------------- QEfficient/transformers/cache_utils.py | 18 +---- 3 files changed, 12 insertions(+), 112 deletions(-) diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index bbccadb70..5a965afb8 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -252,29 +252,10 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - # Process image and text inputs = self._processor(images=image, text=prompt, return_tensors="pt") - if ( - hasattr(self._qeff_model.model.config, "model_type") - and self._qeff_model.model.config.model_type == "qwen2_5_vl" - ): - inputs = self._qeff_model.model.prepare_inputs_for_generation( - inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] - ) - - if ( - hasattr(self._qeff_model.model.config, "model_type") - and self._qeff_model.model.config.model_type == "qwen3_vl_moe" - ): + if (hasattr(self._qeff_model.model.config, "model_type")and self._qeff_model.model.config.model_type in {"qwen2_5_vl", "qwen3_vl_moe", "qwen3_vl"}): inputs = self._qeff_model.model.prepare_inputs_for_generation( - inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] - ) - - if ( - hasattr(self._qeff_model.model.config, "model_type") - and self._qeff_model.model.config.model_type == "qwen3_vl" - ): - inputs = self._qeff_model.model.prepare_inputs_for_generation( - inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] - ) + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) # Convert to float32 if needed if "pixel_values" in inputs: @@ -426,7 +407,7 @@ def setup_vision_buffers(self): buffers = {} for output_name, shape in shapes.items(): # Create placeholder with appropriate dtype - if "vision_embeds" or "deepstack_features" in output_name: + if "vision_embeds" in output_name or "deepstack_features" in output_name: buffers[output_name] = np.zeros(shape, dtype=np.float16) else: buffers[output_name] = np.zeros(shape, dtype=np.float32) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 05e867644..9d47688d7 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -146,15 +146,7 @@ def __init__( ) # Vision-specific initialization - self.is_qwen2_5_vl = ( - hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" - ) - self.is_qwen3_vl_moe = ( - hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe" - ) - self.is_qwen3_vl = ( - hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl" - ) + self.is_qwen_vl = (hasattr(qeff_model.model.config, "model_type")and qeff_model.model.config.model_type in {"qwen2_5_vl", "qwen3_vl_moe", "qwen3_vl"}) self.qeff_model = qeff_model self.processor = processor self.tokenizer = tokenizer @@ -262,37 +254,12 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): outputs, position_ids, generation_len = self.run_prefill( next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) ) - if self.is_qwen2_5_vl: - _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) - elif self.is_qwen3_vl_moe: - _ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id) - elif self.is_qwen3_vl: - _ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id) + if self.is_qwen_vl: + _ = self.update_decode_inputs_qwen_vl(outputs, position_ids, generation_len, decode_batch_id) else: _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) - def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): - """ - Updates the decode input with the generated values. - Args: - outputs (dict): The outputs of the model. - position_ids (array): The position IDs. - generation_len (int): The generation length. - decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. - - Returns: - next_token_id (array): The next token ID. - """ - next_token_id = self._fetch_next_token_id(outputs) - - # Store the generated values. - self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id - self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) - self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) - self.generation_len[decode_batch_id or slice(None)] = generation_len - return next_token_id - - def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_len, decode_batch_id=None): + def update_decode_inputs_qwen_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): """ Updates the decode input with the generated values. Args: @@ -313,26 +280,6 @@ def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_le self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id - def update_decode_inputs_qwen3_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): - """ - Updates the decode input with the generated values. - Args: - outputs (dict): The outputs of the model. - position_ids (array): The position IDs. - generation_len (int): The generation length. - decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. - - Returns: - next_token_id (array): The next token ID. - """ - next_token_id = self._fetch_next_token_id(outputs) - - # Store the generated values. - self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id - self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) - self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) - self.generation_len[decode_batch_id or slice(None)] = generation_len - return next_token_id def _execute_chunked_prefill( self, @@ -632,11 +579,7 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream, max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) - if self.is_qwen2_5_vl: - self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) - if self.is_qwen3_vl_moe: - self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) - if self.is_qwen3_vl: + if self.is_qwen_vl: self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) # Create prompt queue prompt_queue = deque(vision_prompts) @@ -744,16 +687,8 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation generation_len_final = self._fetch_generation_len(generation_len, max_gen_len) # Update decode inputs - if self.is_qwen2_5_vl: - self.update_decode_inputs_qwen2_5_vl( - outputs, position_ids_decode, generation_len_final, decode_batch_id - ) - elif self.is_qwen3_vl_moe: - self.update_decode_inputs_qwen3_vl_moe( - outputs, position_ids_decode, generation_len_final, decode_batch_id - ) - elif self.is_qwen3_vl: - self.update_decode_inputs_qwen3_vl( + if self.is_qwen_vl: + self.update_decode_inputs_qwen_vl( outputs, position_ids_decode, generation_len_final, decode_batch_id ) else: diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 37215702a..3dc7ce514 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -192,7 +192,6 @@ def update( A tuple containing the updated key and value states. """ # Update the cache - # if not self.is_initialized: if self.keys is None: self.keys = key_states @@ -336,8 +335,6 @@ def __init__( layer_class_to_replicate=QEffDynamicLayer, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding, - # args=args, - # kwargs=kwargs, ) else: Cache.__init__( @@ -345,8 +342,6 @@ def __init__( layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding, - # args=args, - # kwargs=kwargs, ) if ddp_cache_data is not None: @@ -434,18 +429,7 @@ def update3D( self.append_new_layers(layer_idx) return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs) - # def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - # """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # # TODO: deprecate this function in favor of `cache_position` - # breakpoint() - # is_empty_layer = ( - # len(self.key_cache) == 0 # no cache in any layer - # or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - # or len(self.key_cache[layer_idx]) == 0 # the layer has no cache - # ) - # layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - # return layer_seq_length - + class QEffEncoderDecoderCache(EncoderDecoderCache): """ From 47dd7483ce74fc108d5420db1f6602b6f9915365 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Tue, 10 Mar 2026 18:16:04 +0000 Subject: [PATCH 13/14] Fix for review comments 2 Signed-off-by: Dipankar Sarkar --- QEfficient/generation/embedding_handler.py | 10 +++++++--- QEfficient/generation/vlm_generation.py | 7 +++++-- QEfficient/transformers/cache_utils.py | 1 - 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py index 5a965afb8..8ac2e1e58 100644 --- a/QEfficient/generation/embedding_handler.py +++ b/QEfficient/generation/embedding_handler.py @@ -252,10 +252,14 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) - # Process image and text inputs = self._processor(images=image, text=prompt, return_tensors="pt") - if (hasattr(self._qeff_model.model.config, "model_type")and self._qeff_model.model.config.model_type in {"qwen2_5_vl", "qwen3_vl_moe", "qwen3_vl"}): + if hasattr(self._qeff_model.model.config, "model_type") and self._qeff_model.model.config.model_type in { + "qwen2_5_vl", + "qwen3_vl_moe", + "qwen3_vl", + }: inputs = self._qeff_model.model.prepare_inputs_for_generation( - inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] - ) + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) # Convert to float32 if needed if "pixel_values" in inputs: diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 9d47688d7..892fc145c 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -146,7 +146,11 @@ def __init__( ) # Vision-specific initialization - self.is_qwen_vl = (hasattr(qeff_model.model.config, "model_type")and qeff_model.model.config.model_type in {"qwen2_5_vl", "qwen3_vl_moe", "qwen3_vl"}) + self.is_qwen_vl = hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type in { + "qwen2_5_vl", + "qwen3_vl_moe", + "qwen3_vl", + } self.qeff_model = qeff_model self.processor = processor self.tokenizer = tokenizer @@ -280,7 +284,6 @@ def update_decode_inputs_qwen_vl(self, outputs, position_ids, generation_len, de self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id - def _execute_chunked_prefill( self, lang_inputs: Dict[str, np.ndarray], diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 3dc7ce514..2b0d50849 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -429,7 +429,6 @@ def update3D( self.append_new_layers(layer_idx) return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs) - class QEffEncoderDecoderCache(EncoderDecoderCache): """ From 615340e6e7da679236375c2feea6b610d17d44bb Mon Sep 17 00:00:00 2001 From: Karthikeya Date: Wed, 11 Mar 2026 15:15:31 +0530 Subject: [PATCH 14/14] Updated disagg example script, Moe changes to resolve issue wrt vision (#849) Signed-off-by: vtirumal --- .../transformers/models/modeling_auto.py | 12 ++-- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 43 ++++++------ .../qwen3_vl_moe/qwen3_vl_disagg_mode.py | 69 ++++++++++--------- 3 files changed, 65 insertions(+), 59 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 11dea4ec5..654b3862d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1036,7 +1036,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - def prefill( + def __update_prefill_transform( self, enable: Optional[bool] = True, enable_chunking: Optional[bool] = False, @@ -1096,10 +1096,10 @@ def export( "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" ) self.hash_params["prefill_only"] = True - self.prefill(enable=True, enable_chunking=enable_chunking) + self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) else: self.hash_params["prefill_only"] = False - self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) return self._export( inputs, @@ -2699,7 +2699,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def prefill( + def __update_prefill_transform( self, enable: Optional[bool] = True, enable_chunking: Optional[bool] = False, @@ -2997,7 +2997,7 @@ def export( raise NotImplementedError( "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" ) - self.prefill(enable=True, enable_chunking=enable_chunking) + self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) self.hash_params.pop("retain_full_kv", None) seq_len = self.get_seq_len_and_handle_specialized_prefill_model( prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking @@ -3008,7 +3008,7 @@ def export( else seq_len ) else: - self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) self.hash_params.pop("prefill_only", None) self.hash_params.pop("NUM_Q_BLOCKS", None) self.hash_params.pop("NUM_FFN_BLOCKS", None) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 076e4cdbd..67c7daf5e 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -628,29 +628,32 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens router_logits = self.gate(x) # [T, E] prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) - top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k] + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) top_w = top_w / torch.einsum("bi->b", top_w)[:, None] top_w = top_w.to(hidden_states.dtype) - - # gate_up_proj: [E, H, 2I], down_proj: [E, I, H] - W_up = self.experts.gate_up_proj - W_dn = self.experts.down_proj - E, H_w, twoI = W_up.shape - I2 = twoI // 2 - routing_weights = torch.zeros_like(prob, dtype=hidden_states.dtype) # [T, E] + routing_weights = torch.zeros((T, self.num_experts), dtype=x.dtype) routing_weights.scatter_(1, top_i, top_w) - expert_out = x.new_zeros((T, H)) - for e in range(E): - rw = routing_weights[:, e].unsqueeze(-1) # [T, 1] - # Split fused [H, 2I] -> [H, I] + [H, I] - W_gate_e = W_up[e, :, :I2] - W_up_e = W_up[e, :, I2:] - W_dn_e = W_dn[e, :, :] - gate = x @ W_gate_e - up = x @ W_up_e - down = (up * act(gate)) @ W_dn_e - expert_out.add_(down * rw) - return expert_out.view(B, S, H), router_logits + + expert_out = torch.zeros_like(x, dtype=x.dtype) + + for e in range(self.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) + + W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] + W_dn_e = self.experts.down_proj[e] # [I, H] + gate_up = x @ W_gate_up_e # [T, 2I] + + I2 = gate_up.shape[-1] // 2 + gate = gate_up[:, :I2] # [T, I] + up = gate_up[:, I2:] # [T, I] + intermediate = up * act(gate) + down = intermediate @ W_dn_e + masked_down = torch.where( + routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype) + ) # TODO: verify and remove + expert_out += masked_down + expert_out = expert_out.to(x.dtype).view(B, S, H) + return expert_out, router_logits class QEffQwen3VLMoeModel(Qwen3VLMoeModel): diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 897eea350..6e3c43951 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -23,7 +23,7 @@ # For faster execution user can run with lesser layers, For Testing Purpose Only # config.vision_config.depth = 9 -# config.text_config.num_hidden_layers = 1 +# config.text_config.num_hidden_layers = 6 # config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( @@ -47,8 +47,10 @@ num_cores=16, num_devices=1, mos=1, + mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=skip_vision, + split_retained_state_io=True, skip_lang=True, use_onnx_subfunctions=True, ) @@ -57,6 +59,8 @@ batch_size=BS, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, + height=354, + width=536, num_cores=16, num_devices=1, mxfp6_matmul=True, @@ -76,6 +80,8 @@ batch_size=BS, prefill_seq_len=1, ctx_len=CTX_LEN, + height=354, + width=536, num_cores=16, num_devices=1, mxfp6_matmul=True, @@ -111,10 +117,11 @@ "role": "user", "content": [ {"type": "image", "image": image}, - {"type": "text", "text": "Descibe all the colors seen in the image."}, + {"type": "text", "text": "Describe all the colors seen in the image."}, ], }, ] + vision_session = QAICInferenceSession(vision_qpc_path.get("vision_qpc_path")) messages = [messages] * BS @@ -163,28 +170,12 @@ vision_inputs_fp16 = {"pixel_values", "image_masks"} vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) -if not skip_vision: - vision_session = QAICInferenceSession(vision_qpc_path.get("vision_qpc_path")) - vision_start = perf_counter() vision_outputs = {} if vision_inputs: vision_outputs = vision_session.run(vision_inputs) vision_end = perf_counter() -if not skip_vision: - vision_session.deactivate() - -lang_prefill_session.activate() -# Skip inputs/outputs -lang_prefill_session.skip_buffers( - [ - x - for x in lang_prefill_session.input_names + lang_prefill_session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] -) - lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} if "position_ids" in inputs: lang_inputs["position_ids"] = inputs["position_ids"] @@ -196,54 +187,67 @@ lang_inputs["image_idx"] = np.array([[0]]) +if not skip_vision: + lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"] + lang_inputs["deepstack_features"] = vision_outputs["deepstack_features"] # RUN prefill lang_start = perf_counter() +lang_prefill_session.set_buffers(vision_outputs) all_outputs = [] chunk_inputs = lang_inputs.copy() for i in range(num_chunks): chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] chunk_inputs["position_ids"] = lang_inputs["position_ids"][..., i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] outputs = lang_prefill_session.run(chunk_inputs) + for i in range(config.text_config.num_hidden_layers): + chunk_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] + chunk_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] chunk_inputs["image_idx"] = outputs["image_idx_output"] prefill_time = perf_counter() - lang_start + vision_end - vision_start print(f"Prefill time : {prefill_time:.2f} secs") -lang_prefill_session.deactivate() -lang_decode_session.activate() -# Skip inputs/outputs -lang_decode_session.skip_buffers( - [ - x - for x in lang_decode_session.input_names + lang_decode_session.output_names - if x.startswith("past_") or x.endswith("_RetainedState") - ] -) - -lang_decode_session.set_buffers(outputs) - all_outputs.append(np.argmax(outputs["logits"])) decode_inputs = { "input_ids": np.argmax(outputs["logits"]).reshape(1, 1), "position_ids": np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1, } +for i in range(config.text_config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] + +decode_inputs["image_idx"] = outputs["image_idx_output"] +decode_inputs["vision_embeds"] = outputs["vision_embeds_RetainedState"] +decode_inputs["deepstack_features"] = outputs["deepstack_features_RetainedState"] + st = perf_counter() decode_out = lang_decode_session.run(decode_inputs) print(f"time for first run of decode with KV as input = {perf_counter() - st} sec\n") all_outputs.append(np.argmax(decode_out["logits"])) -pos_id = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 +pos_id = np.max(decode_inputs["position_ids"], axis=-1, keepdims=True) + 1 loop_decode_inputs = { "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), "position_ids": pos_id, } +for i in range(config.text_config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] +loop_decode_inputs["image_idx"] = decode_out["image_idx_output"] +loop_decode_inputs["vision_embeds"] = decode_out["vision_embeds_RetainedState"] +loop_decode_inputs["deepstack_features"] = decode_out["deepstack_features_RetainedState"] + + st = perf_counter() for i in range(generation_len - 2): decode_out = lang_decode_session.run(loop_decode_inputs) all_outputs.append(np.argmax(decode_out["logits"])) pos_id += 1 + for j in range(config.text_config.num_hidden_layers): + loop_decode_inputs[f"past_key.{j}"] = decode_out[f"past_key.{j}_RetainedState"] + loop_decode_inputs[f"past_value.{j}"] = decode_out[f"past_value.{j}_RetainedState"] loop_decode_inputs.update( { "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), @@ -251,6 +255,5 @@ } ) ft = perf_counter() - print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") print(f"\noutput\n{tokenizer.decode(all_outputs)}")