
Commit 7cf106e

Refactor
Signed-off-by: quic-xiyushi <xiyushi@qti.qualcomm.com>
1 parent d02d04d · commit 7cf106e

2 files changed: +114 -186 lines

QEfficient/transformers/models/modeling_auto.py

Lines changed: 24 additions & 185 deletions
@@ -61,6 +61,7 @@
 )
 from QEfficient.utils.check_ccl_specializations import process_ccl_specializations
 from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs
 
 
 class QEFFTransformersBase(QEFFBaseModel):
@@ -730,28 +731,12 @@ def __init__(self, model, continuous_batching: bool = False, qaic_config: Option
         ----------
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
-        continuous_batching : bool, optional
-            If True, enables continuous batching mode for future compilation and execution.
-            This setting must be consistent across `from_pretrained` and `compile` calls. Default is False.
-        qaic_config : dict, optional
-            A dictionary for QAIC-specific configurations.
-            Only the following keys are supported by the text model of the dual QPC multimodal model:
-            - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
-            - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
-            Additional keys will be ignored.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
         """
         super().__init__(model, **kwargs)
         self.model = model.get_qeff_language_decoder()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
-        self.continuous_batching = continuous_batching
-        self.model.qaic_config = qaic_config
-        # ---Sampling---
-        # Note: SamplerTransform should be applied after all other transforms
-        # are done. The role of the sampler is to just add nodes at the output of the
-        # previous transform function.
-        self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)
 
     def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
         """
@@ -775,98 +760,10 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
         str
             Path to the generated ONNX graph file for the language decoder.
         """
-        if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
-            inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
-                inputs, output_names, dynamic_axes
-            )
         return self._export(
             inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
         )
 
-    def get_sampling_inputs_and_outputs(
-        self,
-        example_inputs: Dict[str, torch.Tensor],
-        output_names: List[str],
-        dynamic_axes: Dict[str, Dict[int, str]],
-    ):
-        """
-        Updates the example inputs, output names, and dynamic axes to include
-        parameters relevant for on-device sampling during ONNX export.
-
-        Parameters
-        ----------
-        example_inputs : Dict[str, torch.Tensor]
-            Current dictionary of example inputs.
-        output_names : List[str]
-            Current list of output names.
-        dynamic_axes : Dict[str, Dict[int, str]]
-            Current dictionary of dynamic axes configurations.
-
-        Returns
-        -------
-        Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
-            Updated example inputs, output names, and dynamic axes including
-            sampling-related parameters.
-        """
-        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
-        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
-
-        assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling"
-
-        logits_index = output_names.index("logits")
-        output_names[logits_index] = "next_tokens"
-
-        example_inputs["last_accepted_output_tokens"] = torch.zeros(
-            (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
-        )
-        dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
-
-        example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
-            (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool
-        )
-        dynamic_axes["past_repetition_penalty_buffer"] = {
-            0: "full_batch_size" if self.continuous_batching else "batch_size",
-        }
-        output_names.append("past_repetition_penalty_buffer_RetainedState")
-
-        example_inputs["repetition_penalties"] = (
-            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
-        )
-        dynamic_axes["repetition_penalties"] = {0: "batch_size"}
-
-        example_inputs["past_presence_penalty_buffer"] = torch.zeros(
-            (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool
-        )
-        dynamic_axes["past_presence_penalty_buffer"] = {
-            0: "full_batch_size" if self.continuous_batching else "batch_size",
-        }
-        output_names.append("past_presence_penalty_buffer_RetainedState")
-
-        example_inputs["presence_penalties"] = (
-            torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
-        )
-        dynamic_axes["presence_penalties"] = {0: "batch_size"}
-
-        example_inputs["temperatures"] = (
-            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
-        )
-        dynamic_axes["temperatures"] = {0: "batch_size"}
-
-        max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
-        example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
-        dynamic_axes["top_ks"] = {0: "batch_size"}
-
-        example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
-        dynamic_axes["top_ps"] = {0: "batch_size"}
-
-        example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
-        dynamic_axes["min_ps"] = {0: "batch_size"}
-
-        example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
-        dynamic_axes["random_numbers"] = {0: "batch_size"}
-
-        return example_inputs, output_names, dynamic_axes
-
     def compile(
         self,
         compile_dir,
@@ -993,7 +890,13 @@ def __init__(
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
         self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs)
         self.continuous_batching = continuous_batching
+        self.lang_model.model.qaic_config = qaic_config
         self.input_shapes, self.output_names = None, None
+        # ---Sampling---
+        # Note: SamplerTransform should be applied after all other transforms
+        # are done. The role of the sampler is to just add nodes at the output of the
+        # previous transform function.
+        self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs)
 
     @property
     def model_name(self) -> str:
@@ -1115,6 +1018,19 @@ def export(
             kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
         )
         output_names = self.model.get_output_names(kv_offload=True)
+        if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get(
+            "include_sampler", False
+        ):
+            logits_index = output_names["lang"].index("logits")
+            output_names["lang"][logits_index] = "next_tokens"
+            inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs(
+                example_inputs=inputs["lang"],
+                output_names=output_names["lang"],
+                dynamic_axes=dynamic_axes["lang"],
+                continuous_batching=self.continuous_batching,
+                vocab_size=self.lang_model.model.config.vocab_size,
+                qaic_config=self.lang_model.model.qaic_config,
+            )
 
         self.vision_model.export(
             inputs["vision"],
@@ -2300,7 +2216,6 @@ def from_pretrained(
             model,
             kv_offload=kv_offload,
             continuous_batching=continuous_batching,
-            qaic_config=qaic_config,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             qaic_config=qaic_config,
             **kwargs,
@@ -2634,10 +2549,13 @@ def export(self, export_dir: Optional[str] = None) -> str:
         dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}
 
         if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
-            example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
+            example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
                 example_inputs=example_inputs,
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
+                continuous_batching=self.continuous_batching,
+                vocab_size=self.model.config.vocab_size,
+                qaic_config=self.model.qaic_config,
             )
 
         return self._export(
@@ -2647,85 +2565,6 @@ def export(self, export_dir: Optional[str] = None) -> str:
             export_dir=export_dir,
         )
 
-    def get_sampling_inputs_and_outputs(
-        self,
-        example_inputs: Dict[str, torch.Tensor],
-        output_names: List[str],
-        dynamic_axes: Dict[str, Dict[int, str]],
-    ):
-        """
-        Updates the example inputs, output names, and dynamic axes to include
-        parameters relevant for on-device sampling during ONNX export.
-
-        Parameters
-        ----------
-        example_inputs : Dict[str, torch.Tensor]
-            Current dictionary of example inputs.
-        output_names : List[str]
-            Current list of output names.
-        dynamic_axes : Dict[str, Dict[int, str]]
-            Current dictionary of dynamic axes configurations.
-
-        Returns
-        -------
-        Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
-            Updated example inputs, output names, and dynamic axes including
-            sampling-related parameters.
-        """
-        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
-        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
-
-        example_inputs["last_accepted_output_tokens"] = torch.zeros(
-            (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
-        )
-        dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
-
-        example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
-            (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
-        )
-        dynamic_axes["past_repetition_penalty_buffer"] = {
-            0: "full_batch_size" if self.continuous_batching else "batch_size",
-        }
-        output_names.append("past_repetition_penalty_buffer_RetainedState")
-
-        example_inputs["repetition_penalties"] = (
-            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
-        )
-        dynamic_axes["repetition_penalties"] = {0: "batch_size"}
-
-        example_inputs["past_presence_penalty_buffer"] = torch.zeros(
-            (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool
-        )
-        dynamic_axes["past_presence_penalty_buffer"] = {
-            0: "full_batch_size" if self.continuous_batching else "batch_size",
-        }
-        output_names.append("past_presence_penalty_buffer_RetainedState")
-
-        example_inputs["presence_penalties"] = (
-            torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
-        )
-        dynamic_axes["presence_penalties"] = {0: "batch_size"}
-
-        example_inputs["temperatures"] = (
-            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
-        )
-        dynamic_axes["temperatures"] = {0: "batch_size"}
-
-        max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
-        example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
-        dynamic_axes["top_ks"] = {0: "batch_size"}
-
-        example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
-        dynamic_axes["top_ps"] = {0: "batch_size"}
-
-        example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
-        dynamic_axes["min_ps"] = {0: "batch_size"}
-
-        example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
-        dynamic_axes["random_numbers"] = {0: "batch_size"}
-
-        return example_inputs, output_names, dynamic_axes
-
     def build_prefill_specialization(
         self,
         prefill_seq_len: int = 32,
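
For context, here is a minimal sketch (not part of this commit) of how a call site drives the relocated helper after this refactor. The example tensor shapes, the vocab size, and the qaic_config values are illustrative placeholders; the real call sites pass the model config's vocab_size and the user-supplied qaic_config, as the hunks above show.

import torch

from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs

# Placeholder export recipe; real recipes come from the model's export path.
inputs = {"input_ids": torch.zeros((1, 32), dtype=torch.int64)}
output_names = ["logits"]
dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}}

# As in the dual-QPC export hunk above, the caller now renames "logits" to
# "next_tokens" itself; the shared helper no longer performs the rename.
output_names[output_names.index("logits")] = "next_tokens"

inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
    example_inputs=inputs,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    continuous_batching=False,
    vocab_size=32000,  # placeholder; taken from the model config in the real call sites
    qaic_config={"include_sampler": True, "max_top_k_ids": 512},
)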

QEfficient/utils/sampler_utils.py

Lines changed: 90 additions & 1 deletion
@@ -5,8 +5,11 @@
 #
 # -----------------------------------------------------------------------------
 
-from typing import Optional, Set
+from typing import Dict, List, Optional, Set
 
+import torch
+
+from QEfficient.utils import constants
 from QEfficient.utils.constants import Constants
 from QEfficient.utils.logging_utils import logger
 
@@ -56,3 +59,89 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional
         )
 
     return session_includes_sampler
+
+
+def get_sampling_inputs_and_outputs(
+    example_inputs: Dict[str, torch.Tensor],
+    output_names: List[str],
+    dynamic_axes: Dict[str, Dict[int, str]],
+    continuous_batching: bool,
+    vocab_size: int,
+    qaic_config: Dict,
+):
+    """
+    Updates the example inputs, output names, and dynamic axes to include
+    parameters relevant for on-device sampling during ONNX export.
+
+    Parameters
+    ----------
+    example_inputs : Dict[str, torch.Tensor]
+        Current dictionary of example inputs.
+    output_names : List[str]
+        Current list of output names.
+    dynamic_axes : Dict[str, Dict[int, str]]
+        Current dictionary of dynamic axes configurations.
+    continuous_batching : bool
+        Whether this model will be used for continuous batching in the future.
+    vocab_size: int
+        Vocabulary size for this model.
+    qaic_config : Dict
+        QAIC config dictionary.
+
+    Returns
+    -------
+    Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
+        Updated example inputs, output names, and dynamic axes including
+        sampling-related parameters.
+    """
+    bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+    fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+
+    example_inputs["last_accepted_output_tokens"] = torch.zeros(
+        (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
+    )
+    dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
+
+    example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
+        (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool
+    )
+    dynamic_axes["past_repetition_penalty_buffer"] = {
+        0: "full_batch_size" if continuous_batching else "batch_size",
+    }
+    output_names.append("past_repetition_penalty_buffer_RetainedState")
+
+    example_inputs["repetition_penalties"] = (
+        torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
+    )
+    dynamic_axes["repetition_penalties"] = {0: "batch_size"}
+
+    example_inputs["past_presence_penalty_buffer"] = torch.zeros(
+        (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool
+    )
+    dynamic_axes["past_presence_penalty_buffer"] = {
+        0: "full_batch_size" if continuous_batching else "batch_size",
+    }
+    output_names.append("past_presence_penalty_buffer_RetainedState")
+
+    example_inputs["presence_penalties"] = (
+        torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
+    )
+    dynamic_axes["presence_penalties"] = {0: "batch_size"}
+
+    example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
+    dynamic_axes["temperatures"] = {0: "batch_size"}
+
+    max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
+    example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
+    dynamic_axes["top_ks"] = {0: "batch_size"}
+
+    example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
+    dynamic_axes["top_ps"] = {0: "batch_size"}
+
+    example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
+    dynamic_axes["min_ps"] = {0: "batch_size"}
+
+    example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
+    dynamic_axes["random_numbers"] = {0: "batch_size"}
+
+    return example_inputs, output_names, dynamic_axes
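
Continuing the sketch shown after the first file's diff (same placeholder variables), the helper's net effect on an export recipe can be sanity-checked roughly like this; the key names are read directly from the function body above:

# Inputs the helper registers for on-device sampling.
sampling_inputs = {
    "last_accepted_output_tokens",
    "past_repetition_penalty_buffer",
    "repetition_penalties",
    "past_presence_penalty_buffer",
    "presence_penalties",
    "temperatures",
    "top_ks",
    "top_ps",
    "min_ps",
    "random_numbers",
}
assert sampling_inputs.issubset(inputs.keys())

# Retained-state buffers surface as extra outputs.
assert "past_repetition_penalty_buffer_RetainedState" in output_names
assert "past_presence_penalty_buffer_RetainedState" in output_names

# With continuous_batching=True, the two penalty buffers would instead be
# allocated with the example full batch size (fbs) and their batch axis
# would be named "full_batch_size".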
