diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index ef7e83adf..0378a800f 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -261,7 +261,12 @@ def _export(
             )
             apply_torch_patches()
             InvalidIndexProvider.SUBFUNC_ENABLED = True
-            output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
+            output_names = [
+                re.sub("_RetainedState", "_InternalRetainedState", name)
+                if name.endswith("_RetainedState") and ("key" in name or "value" in name)
+                else name
+                for name in output_names
+            ]
             export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
             self._onnx_transforms.append(RenameFunctionOutputsTransform)
             self._onnx_transforms.append(CustomOpTransform)
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 21a867eb5..e64634b62 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -832,23 +832,26 @@ def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
     Dynamically determine which DecoderLayer classes should be exported as functions
     based on the model's architecture using the existing KVCacheTransform mapping.
     """
-    # Define patterns that identify decoder layer classes
-    DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
-    # Get all QEff classes that are decoder layers from the existing mapping
+    DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
     decoder_layer_classes = set()
     for original_class, qeff_class in KVCacheTransform._module_mapping.items():
-        # Check if the QEff class name contains decoder layer patterns
         qeff_class_name = qeff_class.__name__
         if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS):
             decoder_layer_classes.add(qeff_class)
-    # Filter to only include classes that are actually used in the current model
     model_decoder_classes = set()
-    for module in model.modules():
-        if module.__class__ in decoder_layer_classes:
-            model_decoder_classes.add(module.__class__)
+    model_class_name = model.__class__.__name__
+    if "EncoderWrapper" in model_class_name:
+        model_decoder_classes.update(
+            module.__class__ for module in model.modules() if "Qwen2_5_VLVisionBlock" in module.__class__.__name__
+        )
+        return model_decoder_classes
+
+    model_decoder_classes.update(
+        module.__class__ for module in model.modules() if module.__class__ in decoder_layer_classes
+    )
     return model_decoder_classes
diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 63e046600..8b8971e59 100644
--- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -73,14 +73,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" - - mrope_section = mrope_section * 2 cos = cos[position_ids] sin = sin[position_ids] - - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) - + cos = torch.cat([cos[0, ..., 0:32], cos[0, ..., 32:80], cos[0, ..., 80:128]], dim=-1).unsqueeze(0) + sin = torch.cat([sin[0, ..., 0:32], sin[0, ..., 32:80], sin[0, ..., 80:128]], dim=-1).unsqueeze(0) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) diff --git a/QEfficient/utils/torch_patches.py b/QEfficient/utils/torch_patches.py index 0b9b37afa..241b32fbf 100644 --- a/QEfficient/utils/torch_patches.py +++ b/QEfficient/utils/torch_patches.py @@ -7,6 +7,8 @@ """Monkey patches for torch.onnx.utils to fix ONNX export issues.""" +import warnings + import torch import torch.onnx.utils as onnx_utils from torch import _C @@ -37,9 +39,13 @@ def _track_module_attributes_forward_hook(module, input, output): if hasattr(module, attr_name): onnx_attrs = getattr(module, attr_name) delattr(module, attr_name) + # FIX: use empty dict to avoid type mismatch - onnx_attrs = {} - _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + # onnx_attrs = {} + try: + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + except Exception as e: + warnings.warn(f"Failed to track ONNX scope attributes: {e}. Skipping this step.") for m in model.modules(): m.register_forward_hook(_track_module_attributes_forward_hook)