6161)
6262from QEfficient .utils .check_ccl_specializations import process_ccl_specializations
6363from QEfficient .utils .logging_utils import logger
64+ from QEfficient .utils .sampler_utils import get_sampling_inputs_and_outputs
6465
6566
6667class QEFFTransformersBase (QEFFBaseModel ):
@@ -730,28 +731,12 @@ def __init__(self, model, continuous_batching: bool = False, qaic_config: Option
730731 ----------
731732 model : nn.Module
732733 The full HuggingFace multimodal model from which the language decoder is extracted.
733- continuous_batching : bool, optional
734- If True, enables continuous batching mode for future compilation and execution.
735- This setting must be consistent across `from_pretrained` and `compile` calls. Default is False.
736- qaic_config : dict, optional
737- A dictionary for QAIC-specific configurations.
738- Only the following keys are supported by the text model of the dual QPC multimodal model:
739- - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
740- - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
741- Additional keys will be ignored.
742734 **kwargs :
743735 Additional keyword arguments passed to the base class constructor.
744736 """
745737 super ().__init__ (model , ** kwargs )
746738 self .model = model .get_qeff_language_decoder ()
747739 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
748- self .continuous_batching = continuous_batching
749- self .model .qaic_config = qaic_config
750- # ---Sampling---
751- # Note: SamplerTransform should be applied after all other transforms
752- # are done. The role of the sampler is to just add nodes at the output of the
753- # previous transform function.
754- self .model , _ = SamplerTransform .apply (self .model , qaic_config , ** kwargs )
755740
756741 def export (self , inputs , output_names , dynamic_axes , export_dir = None , offload_pt_weights = True ):
757742 """
@@ -775,98 +760,10 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
775760 str
776761 Path to the generated ONNX graph file for the language decoder.
777762 """
778- if self .model .qaic_config is not None and self .model .qaic_config .get ("include_sampler" , False ):
779- inputs , output_names , dynamic_axes = self .get_sampling_inputs_and_outputs (
780- inputs , output_names , dynamic_axes
781- )
782763 return self ._export (
783764 inputs , output_names , dynamic_axes , export_dir = export_dir , offload_pt_weights = offload_pt_weights
784765 )
785766
786- def get_sampling_inputs_and_outputs (
787- self ,
788- example_inputs : Dict [str , torch .Tensor ],
789- output_names : List [str ],
790- dynamic_axes : Dict [str , Dict [int , str ]],
791- ):
792- """
793- Updates the example inputs, output names, and dynamic axes to include
794- parameters relevant for on-device sampling during ONNX export.
795-
796- Parameters
797- ----------
798- example_inputs : Dict[str, torch.Tensor]
799- Current dictionary of example inputs.
800- output_names : List[str]
801- Current list of output names.
802- dynamic_axes : Dict[str, Dict[int, str]]
803- Current dictionary of dynamic axes configurations.
804-
805- Returns
806- -------
807- Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
808- Updated example inputs, output names, and dynamic axes including
809- sampling-related parameters.
810- """
811- bs : int = constants .ONNX_EXPORT_EXAMPLE_BATCH_SIZE
812- fbs : int = constants .ONNX_EXPORT_EXAMPLE_FBS
813-
814- assert "logits" in output_names , "logits must be part of the output names to suport on-device sampling"
815-
816- logits_index = output_names .index ("logits" )
817- output_names [logits_index ] = "next_tokens"
818-
819- example_inputs ["last_accepted_output_tokens" ] = torch .zeros (
820- (bs , constants .ONNX_EXPORT_EXAMPLE_SEQ_LEN ), dtype = torch .int64
821- )
822- dynamic_axes ["last_accepted_output_tokens" ] = {0 : "batch_size" , 1 : "seq_len" }
823-
824- example_inputs ["past_repetition_penalty_buffer" ] = torch .zeros (
825- (fbs if self .continuous_batching else bs , self .model .language_model .config .vocab_size ), dtype = torch .bool
826- )
827- dynamic_axes ["past_repetition_penalty_buffer" ] = {
828- 0 : "full_batch_size" if self .continuous_batching else "batch_size" ,
829- }
830- output_names .append ("past_repetition_penalty_buffer_RetainedState" )
831-
832- example_inputs ["repetition_penalties" ] = (
833- torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
834- )
835- dynamic_axes ["repetition_penalties" ] = {0 : "batch_size" }
836-
837- example_inputs ["past_presence_penalty_buffer" ] = torch .zeros (
838- (fbs if self .continuous_batching else bs , self .model .language_model .config .vocab_size ), dtype = torch .bool
839- )
840- dynamic_axes ["past_presence_penalty_buffer" ] = {
841- 0 : "full_batch_size" if self .continuous_batching else "batch_size" ,
842- }
843- output_names .append ("past_presence_penalty_buffer_RetainedState" )
844-
845- example_inputs ["presence_penalties" ] = (
846- torch .zeros ((bs , 1 ), dtype = torch .float ) + constants .ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
847- )
848- dynamic_axes ["presence_penalties" ] = {0 : "batch_size" }
849-
850- example_inputs ["temperatures" ] = (
851- torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_TEMPERATURES
852- )
853- dynamic_axes ["temperatures" ] = {0 : "batch_size" }
854-
855- max_top_k_ids = self .model .qaic_config .get ("max_top_k_ids" , constants .ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS )
856- example_inputs ["top_ks" ] = torch .randint (1 , max_top_k_ids , size = (bs , 1 )).to (torch .int32 )
857- dynamic_axes ["top_ks" ] = {0 : "batch_size" }
858-
859- example_inputs ["top_ps" ] = torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_TOP_PS
860- dynamic_axes ["top_ps" ] = {0 : "batch_size" }
861-
862- example_inputs ["min_ps" ] = torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_MIN_PS
863- dynamic_axes ["min_ps" ] = {0 : "batch_size" }
864-
865- example_inputs ["random_numbers" ] = torch .rand ((bs , max_top_k_ids ), dtype = torch .float )
866- dynamic_axes ["random_numbers" ] = {0 : "batch_size" }
867-
868- return example_inputs , output_names , dynamic_axes
869-
870767 def compile (
871768 self ,
872769 compile_dir ,
@@ -993,7 +890,13 @@ def __init__(
993890 self .vision_model = QEffVisionEncoderForTextImageToTextModel (model , ** kwargs )
994891 self .lang_model = QEffCausalLMForTextImageToTextModel (model , continuous_batching = continuous_batching , ** kwargs )
995892 self .continuous_batching = continuous_batching
893+ self .lang_model .model .qaic_config = qaic_config
996894 self .input_shapes , self .output_names = None , None
895+ # ---Sampling---
896+ # Note: SamplerTransform should be applied after all other transforms
897+ # are done. The role of the sampler is to just add nodes at the output of the
898+ # previous transform function.
899+ self .lang_model .model , _ = SamplerTransform .apply (self .lang_model .model , qaic_config , ** kwargs )
997900
998901 @property
999902 def model_name (self ) -> str :
@@ -1115,6 +1018,19 @@ def export(
11151018 kv_offload = True , comp_ctx_lengths = self .comp_ctx_lengths_decode
11161019 )
11171020 output_names = self .model .get_output_names (kv_offload = True )
1021+ if self .lang_model .model .qaic_config is not None and self .lang_model .model .qaic_config .get (
1022+ "include_sampler" , False
1023+ ):
1024+ logits_index = output_names ["lang" ].index ("logits" )
1025+ output_names ["lang" ][logits_index ] = "next_tokens"
1026+ inputs ["lang" ], output_names ["lang" ], dynamic_axes ["lang" ] = get_sampling_inputs_and_outputs (
1027+ example_inputs = inputs ["lang" ],
1028+ output_names = output_names ["lang" ],
1029+ dynamic_axes = dynamic_axes ["lang" ],
1030+ continuous_batching = self .continuous_batching ,
1031+ vocab_size = self .lang_model .model .config .vocab_size ,
1032+ qaic_config = self .lang_model .model .qaic_config ,
1033+ )
11181034
11191035 self .vision_model .export (
11201036 inputs ["vision" ],
@@ -2300,7 +2216,6 @@ def from_pretrained(
23002216 model ,
23012217 kv_offload = kv_offload ,
23022218 continuous_batching = continuous_batching ,
2303- qaic_config = qaic_config ,
23042219 pretrained_model_name_or_path = pretrained_model_name_or_path ,
23052220 qaic_config = qaic_config ,
23062221 ** kwargs ,
@@ -2634,10 +2549,13 @@ def export(self, export_dir: Optional[str] = None) -> str:
26342549 dynamic_axes ["num_logits_to_keep" ] = {0 : "num_logits_to_keep" }
26352550
26362551 if self .model .qaic_config is not None and self .model .qaic_config .get ("include_sampler" , False ):
2637- example_inputs , output_names , dynamic_axes = self . get_sampling_inputs_and_outputs (
2552+ example_inputs , output_names , dynamic_axes = get_sampling_inputs_and_outputs (
26382553 example_inputs = example_inputs ,
26392554 output_names = output_names ,
26402555 dynamic_axes = dynamic_axes ,
2556+ continuous_batching = self .continuous_batching ,
2557+ vocab_size = self .model .config .vocab_size ,
2558+ qaic_config = self .model .qaic_config ,
26412559 )
26422560
26432561 return self ._export (
@@ -2647,85 +2565,6 @@ def export(self, export_dir: Optional[str] = None) -> str:
26472565 export_dir = export_dir ,
26482566 )
26492567
2650- def get_sampling_inputs_and_outputs (
2651- self ,
2652- example_inputs : Dict [str , torch .Tensor ],
2653- output_names : List [str ],
2654- dynamic_axes : Dict [str , Dict [int , str ]],
2655- ):
2656- """
2657- Updates the example inputs, output names, and dynamic axes to include
2658- parameters relevant for on-device sampling during ONNX export.
2659-
2660- Parameters
2661- ----------
2662- example_inputs : Dict[str, torch.Tensor]
2663- Current dictionary of example inputs.
2664- output_names : List[str]
2665- Current list of output names.
2666- dynamic_axes : Dict[str, Dict[int, str]]
2667- Current dictionary of dynamic axes configurations.
2668-
2669- Returns
2670- -------
2671- Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
2672- Updated example inputs, output names, and dynamic axes including
2673- sampling-related parameters.
2674- """
2675- bs : int = constants .ONNX_EXPORT_EXAMPLE_BATCH_SIZE
2676- fbs : int = constants .ONNX_EXPORT_EXAMPLE_FBS
2677-
2678- example_inputs ["last_accepted_output_tokens" ] = torch .zeros (
2679- (bs , constants .ONNX_EXPORT_EXAMPLE_SEQ_LEN ), dtype = torch .int64
2680- )
2681- dynamic_axes ["last_accepted_output_tokens" ] = {0 : "batch_size" , 1 : "seq_len" }
2682-
2683- example_inputs ["past_repetition_penalty_buffer" ] = torch .zeros (
2684- (fbs if self .continuous_batching else bs , self .model .config .vocab_size ), dtype = torch .bool
2685- )
2686- dynamic_axes ["past_repetition_penalty_buffer" ] = {
2687- 0 : "full_batch_size" if self .continuous_batching else "batch_size" ,
2688- }
2689- output_names .append ("past_repetition_penalty_buffer_RetainedState" )
2690-
2691- example_inputs ["repetition_penalties" ] = (
2692- torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
2693- )
2694- dynamic_axes ["repetition_penalties" ] = {0 : "batch_size" }
2695-
2696- example_inputs ["past_presence_penalty_buffer" ] = torch .zeros (
2697- (fbs if self .continuous_batching else bs , self .model .config .vocab_size ), dtype = torch .bool
2698- )
2699- dynamic_axes ["past_presence_penalty_buffer" ] = {
2700- 0 : "full_batch_size" if self .continuous_batching else "batch_size" ,
2701- }
2702- output_names .append ("past_presence_penalty_buffer_RetainedState" )
2703-
2704- example_inputs ["presence_penalties" ] = (
2705- torch .zeros ((bs , 1 ), dtype = torch .float ) + constants .ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
2706- )
2707- dynamic_axes ["presence_penalties" ] = {0 : "batch_size" }
2708-
2709- example_inputs ["temperatures" ] = (
2710- torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_TEMPERATURES
2711- )
2712- dynamic_axes ["temperatures" ] = {0 : "batch_size" }
2713-
2714- max_top_k_ids = self .model .qaic_config .get ("max_top_k_ids" , constants .ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS )
2715- example_inputs ["top_ks" ] = torch .randint (1 , max_top_k_ids , size = (bs , 1 )).to (torch .int32 )
2716- dynamic_axes ["top_ks" ] = {0 : "batch_size" }
2717-
2718- example_inputs ["top_ps" ] = torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_TOP_PS
2719- dynamic_axes ["top_ps" ] = {0 : "batch_size" }
2720-
2721- example_inputs ["min_ps" ] = torch .ones ((bs , 1 ), dtype = torch .float ) * constants .ONNX_EXPORT_EXAMPLE_MIN_PS
2722- dynamic_axes ["min_ps" ] = {0 : "batch_size" }
2723-
2724- example_inputs ["random_numbers" ] = torch .rand ((bs , max_top_k_ids ), dtype = torch .float )
2725- dynamic_axes ["random_numbers" ] = {0 : "batch_size" }
2726-
2727- return example_inputs , output_names , dynamic_axes
2728-
27292568 def build_prefill_specialization (
27302569 self ,
27312570 prefill_seq_len : int = 32 ,
0 commit comments