@@ -721,14 +721,17 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
721721 ]
722722 _onnx_transforms = [FP16ClipTransform , SplitTensorsTransform ]
723723
724- def __init__ (self , model , qaic_config : Optional [dict ] = None , ** kwargs ):
724+ def __init__ (self , model , continuous_batching : bool = False , qaic_config : Optional [dict ] = None , ** kwargs ):
725725 """
726726 Initializes the language decoder component for multimodal models.
727727
728728 Parameters
729729 ----------
730730 model : nn.Module
731731 The full HuggingFace multimodal model from which the language decoder is extracted.
732+ continuous_batching : bool, optional
733+ If True, enables continuous batching mode for future compilation and execution.
734+ This setting must be consistent across `from_pretrained` and `compile` calls. Default is False.
732735 qaic_config : dict, optional
733736 A dictionary for QAIC-specific configurations.
734737 Only the following keys are supported by the text model of the dual QPC multimodal model:
@@ -741,6 +744,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
741744 super ().__init__ (model , ** kwargs )
742745 self .model = model .get_qeff_language_decoder ()
743746 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
747+ self .continuous_batching = continuous_batching
744748 self .model .qaic_config = qaic_config
745749 # ---Sampling---
746750 # Note: SamplerTransform should be applied after all other transforms
@@ -804,6 +808,7 @@ def get_sampling_inputs_and_outputs(
804808 sampling-related parameters.
805809 """
806810 bs : int = constants .ONNX_EXPORT_EXAMPLE_BATCH_SIZE
811+ fbs : int = constants .ONNX_EXPORT_EXAMPLE_FBS
807812
808813 assert "logits" in output_names , "logits must be part of the output names to suport on-device sampling"
809814
@@ -816,10 +821,10 @@ def get_sampling_inputs_and_outputs(
816821 dynamic_axes ["last_accepted_output_tokens" ] = {0 : "batch_size" , 1 : "seq_len" }
817822
818823 example_inputs ["past_repetition_penalty_buffer" ] = torch .zeros (
819- (bs , self .model .language_model .config .vocab_size ), dtype = torch .bool
824+ (fbs if self . continuous_batching else bs , self .model .language_model .config .vocab_size ), dtype = torch .bool
820825 )
821826 dynamic_axes ["past_repetition_penalty_buffer" ] = {
822- 0 : "batch_size" ,
827+ 0: "full_batch_size" if self.continuous_batching else "batch_size",
823828 }
824829 output_names .append ("past_repetition_penalty_buffer_RetainedState" )
825830
@@ -829,10 +834,10 @@ def get_sampling_inputs_and_outputs(
829834 dynamic_axes ["repetition_penalties" ] = {0 : "batch_size" }
830835
831836 example_inputs ["past_presence_penalty_buffer" ] = torch .zeros (
832- (bs , self .model .language_model .config .vocab_size ), dtype = torch .bool
837+ (fbs if self . continuous_batching else bs , self .model .language_model .config .vocab_size ), dtype = torch .bool
833838 )
834839 dynamic_axes ["past_presence_penalty_buffer" ] = {
835- 0 : "batch_size" ,
840+ 0: "full_batch_size" if self.continuous_batching else "batch_size",
836841 }
837842 output_names .append ("past_presence_penalty_buffer_RetainedState" )
838843
@@ -981,7 +986,7 @@ def __init__(
981986 self .model = model
982987 self .config = model .config
983988 self .vision_model = QEffVisionEncoderForTextImageToTextModel (model , ** kwargs )
984- self .lang_model = QEffCausalLMForTextImageToTextModel (model , ** kwargs )
989+ self .lang_model = QEffCausalLMForTextImageToTextModel (model , continuous_batching = continuous_batching , ** kwargs )
985990 self .continuous_batching = continuous_batching
986991 self .input_shapes , self .output_names = None , None
987992
0 commit comments