@@ -8,7 +8,7 @@
 import warnings
 from pathlib import Path
 from time import perf_counter
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -752,19 +752,22 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): |
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model, **kwargs):
+    def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
         """
         Initializes the language decoder component for multimodal models.
 
         Parameters
         ----------
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
         """
-        super().__init__(model, **kwargs)
+        super().__init__(model, qaic_config=qaic_config, **kwargs)
         self.model = model.get_qeff_language_decoder()
+        self.model.qaic_config = qaic_config
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
 
     def export(
@@ -936,9 +939,8 @@ def __init__( |
         self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
 
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
-        self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
+        self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs)
         self.continuous_batching = continuous_batching
-        self.lang_model.model.qaic_config = qaic_config
         self.input_shapes, self.output_names = None, None
         # ---Sampling---
         # Note: SamplerTransform should be applied after all other transforms
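
With this change, `qaic_config` is threaded through the constructor instead of being patched onto `lang_model.model` after the fact. A minimal sketch of the resulting flow, assuming an already-loaded HuggingFace multimodal `model`; the config keys shown are illustrative and follow the sampler options documented further down:

```python
# Illustrative only: key names mirror the qaic_config docstring below.
qaic_config = {"include_sampler": True, "max_top_k_ids": 512}

lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config)

# The extracted language decoder now carries the same config object,
# which previously had to be assigned manually by the caller.
assert lang_model.model.qaic_config is qaic_config
```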
@@ -2286,7 +2288,6 @@ def from_pretrained( |
             logger.warning("Updating low_cpu_mem_usage=False")
 
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
@@ -2365,6 +2366,8 @@ def __init__( |
             - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
               For Speculative Decoding Target Language Models, this is always True.
             - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+              during decoding. Only takes effect when ``include_sampler=True``.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
 
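
Putting the documented options together, a guided-decoding configuration might look like the sketch below; the key names come from this docstring, while the values are placeholders:

```python
# Hypothetical qaic_config enabling guided decoding; values are illustrative.
qaic_config = {
    "include_sampler": True,          # guided decoding requires the sampler
    "return_pdfs": False,             # always True for SpD target models
    "max_top_k_ids": 512,             # must be <= vocab size
    "include_guided_decoding": True,  # token-level filtering during decode
}
```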
@@ -2467,6 +2470,8 @@ def from_pretrained( |
               and ``return_pdfs=False`` for regular model.
             - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
               The values provided in ``top_ks`` tensor must be less than this maximum limit.
+            - **include_guided_decoding** (bool): If True, enables guided token-level filtering
+              during decoding. Only takes effect when ``include_sampler=True``.
 
         *args :
             Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`.
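
For an end-to-end view, the same config would be passed at load time. A hedged sketch, assuming the `QEFFAutoModelForImageTextToText` entry point and a placeholder checkpoint name:

```python
from QEfficient import QEFFAutoModelForImageTextToText  # assumed entry point

# Placeholder checkpoint; any supported multimodal model would do.
model = QEFFAutoModelForImageTextToText.from_pretrained(
    "org/multimodal-checkpoint",
    qaic_config={
        "include_sampler": True,
        "include_guided_decoding": True,  # only takes effect with the sampler
        "max_top_k_ids": 512,
    },
)
```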