7 changes: 6 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -261,7 +261,12 @@ def _export(
         )
         apply_torch_patches()
         InvalidIndexProvider.SUBFUNC_ENABLED = True
-        output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
+        output_names = [
+            re.sub("_RetainedState", "_InternalRetainedState", name)
+            if name.endswith("_RetainedState") and ("key" in name or "value" in name)
+            else name
+            for name in output_names
+        ]
         export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
         self._onnx_transforms.append(RenameFunctionOutputsTransform)
         self._onnx_transforms.append(CustomOpTransform)
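Note on the hunk above: the blanket "_RetainedState" rename is narrowed so that only KV-cache outputs (names containing "key" or "value") are rewritten; other retained-state outputs keep their names. A minimal standalone sketch of the new behavior, with illustrative output names:

    import re

    # Illustrative output names; only key/value retained-state entries change.
    output_names = ["logits", "past_key.0_RetainedState", "past_value.0_RetainedState"]
    output_names = [
        re.sub("_RetainedState", "_InternalRetainedState", name)
        if name.endswith("_RetainedState") and ("key" in name or "value" in name)
        else name
        for name in output_names
    ]
    print(output_names)
    # ['logits', 'past_key.0_InternalRetainedState', 'past_value.0_InternalRetainedState']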
19 changes: 11 additions & 8 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -832,23 +832,26 @@ def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
     Dynamically determine which DecoderLayer classes should be exported as functions
     based on the model's architecture using the existing KVCacheTransform mapping.
     """
-    # Define patterns that identify decoder layer classes
-    DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
-
-    # Get all QEff classes that are decoder layers from the existing mapping
+    DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
     decoder_layer_classes = set()
 
     for original_class, qeff_class in KVCacheTransform._module_mapping.items():
-        # Check if the QEff class name contains decoder layer patterns
         qeff_class_name = qeff_class.__name__
         if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS):
            decoder_layer_classes.add(qeff_class)
 
     # Filter to only include classes that are actually used in the current model
     model_decoder_classes = set()
-    for module in model.modules():
-        if module.__class__ in decoder_layer_classes:
-            model_decoder_classes.add(module.__class__)
+    model_class_name = model.__class__.__name__
+    if "EncoderWrapper" in model_class_name:
+        model_decoder_classes.update(
+            module.__class__ for module in model.modules() if "Qwen2_5_VLVisionBlock" in module.__class__.__name__
+        )
+        return model_decoder_classes
+
+    model_decoder_classes.update(
+        module.__class__ for module in model.modules() if module.__class__ in decoder_layer_classes
+    )
 
     return model_decoder_classes

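Usage note: the set returned by get_decoder_layer_classes_for_export is fed to torch.onnx.export via the export_modules_as_functions argument (visible in the modeling_qeff.py hunk above), which accepts a set of nn.Module subclasses to emit as local ONNX functions. A minimal sketch, with the model and example inputs illustrative:

    import torch

    decoder_classes = get_decoder_layer_classes_for_export(model)
    torch.onnx.export(
        model,
        (example_inputs,),  # illustrative positional inputs
        "model.onnx",
        export_modules_as_functions=decoder_classes,  # one ONNX function per class
    )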
@@ -73,14 +73,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-
-    mrope_section = mrope_section * 2
     cos = cos[position_ids]
     sin = sin[position_ids]
-
-    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
-    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
-
+    cos = torch.cat([cos[0, ..., 0:32], cos[0, ..., 32:80], cos[0, ..., 80:128]], dim=-1).unsqueeze(0)
+    sin = torch.cat([sin[0, ..., 0:32], sin[0, ..., 32:80], sin[0, ..., 80:128]], dim=-1).unsqueeze(0)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
 
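For reference, the hardcoded slice boundaries 0:32, 32:80, 80:128 correspond to the doubled mrope_section the removed code computed: assuming mrope_section = [16, 24, 24] (Qwen2.5-VL's default), mrope_section * 2 = [32, 48, 48], whose cumulative offsets over the 128-dim rotary axis are exactly those boundaries. A small sketch of the arithmetic under that assumption:

    import itertools

    mrope_section = [16, 24, 24]                   # assumed Qwen2.5-VL sections
    doubled = [s * 2 for s in mrope_section]       # [32, 48, 48]
    bounds = [0, *itertools.accumulate(doubled)]   # [0, 32, 80, 128]
    print(list(zip(bounds, bounds[1:])))           # [(0, 32), (32, 80), (80, 128)]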
10 changes: 8 additions & 2 deletions QEfficient/utils/torch_patches.py
@@ -7,6 +7,8 @@
 
 """Monkey patches for torch.onnx.utils to fix ONNX export issues."""
 
+import warnings
+
 import torch
 import torch.onnx.utils as onnx_utils
 from torch import _C
@@ -37,9 +39,13 @@ def _track_module_attributes_forward_hook(module, input, output):
         if hasattr(module, attr_name):
             onnx_attrs = getattr(module, attr_name)
             delattr(module, attr_name)
 
         # FIX: use empty dict to avoid type mismatch
-        onnx_attrs = {}
-        _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
+        # onnx_attrs = {}
+        try:
+            _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
+        except Exception as e:
+            warnings.warn(f"Failed to track ONNX scope attributes: {e}. Skipping this step.")
+
     for m in model.modules():
         m.register_forward_hook(_track_module_attributes_forward_hook)
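Mechanism note: apply_torch_patches() (called in the modeling_qeff.py hunk above) rebinds functions on torch.onnx.utils at runtime rather than forking torch. A minimal sketch of that monkey-patch pattern; treating _setup_trace_module_map (the torch.onnx.utils function that defines this forward hook) as the patched target is an assumption:

    import torch.onnx.utils as onnx_utils

    # Keep a handle to the original, then rebind the module attribute so every
    # caller inside torch picks up the patched version.
    _original_setup = onnx_utils._setup_trace_module_map

    def _patched_setup_trace_module_map(*args, **kwargs):
        # A real patch would install the fixed hook shown in the diff above;
        # here we simply delegate to the original implementation.
        return _original_setup(*args, **kwargs)

    onnx_utils._setup_trace_module_map = _patched_setup_trace_module_map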