Skip to content

Commit 8d00cb1

Browse files
authored
Fix hash for VLM's language decoder to include qaic_config
Signed-off-by: quic-xiyushi <xiyushi@qti.qualcomm.com>
1 parent 3789d5a commit 8d00cb1

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
 import warnings
 from pathlib import Path
 from time import perf_counter
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

 import numpy as np
 import torch
@@ -752,19 +752,22 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

-    def __init__(self, model, **kwargs):
+    def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
         """
         Initializes the language decoder component for multimodal models.

         Parameters
         ----------
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
         """
-        super().__init__(model, **kwargs)
+        super().__init__(model, qaic_config=qaic_config, **kwargs)
         self.model = model.get_qeff_language_decoder()
+        self.model.qaic_config = qaic_config
         self.hash_params["qeff_auto_class"] = self.__class__.__name__

     def export(
@@ -936,9 +939,8 @@ def __init__(
         self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
-        self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
+        self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs)
         self.continuous_batching = continuous_batching
-        self.lang_model.model.qaic_config = qaic_config
         self.input_shapes, self.output_names = None, None
         # ---Sampling---
         # Note: SamplerTransform should be applied after all other transforms
@@ -2286,7 +2288,6 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
@@ -2365,6 +2366,8 @@ def __init__(
             - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
               For Speculative Decoding Target Language Models, this is always True.
             - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+              during decoding. Only works when include_sampler=True.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
@@ -2467,6 +2470,8 @@ def from_pretrained(
               and ``return_pdfs=False`` for regular model.
             - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
               The values provided in ``top_ks`` tensor must be less than this maximum limit.
+            - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+              during decoding. Only works when include_sampler=True.

         *args :
             Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`.

0 commit comments

Comments
 (0)