@@ -909,7 +909,7 @@ def __init__(
909909 self ,
910910 model : nn .Module ,
911911 continuous_batching : bool = False ,
912- qaic_config : Optional [ dict ] = None ,
912+ ccl_enabled : bool = False ,
913913 ** kwargs ,
914914 ):
915915 """
@@ -932,11 +932,10 @@ def __init__(
932932 self .model = model
933933 self .config = model .config
934934
935- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (qaic_config )
936-
937935 self .vision_model = QEffVisionEncoderForTextImageToTextModel (model , ** kwargs )
938936 self .lang_model = QEffCausalLMForTextImageToTextModel (model , ** kwargs )
939937 self .continuous_batching = continuous_batching
938+ self .ccl_enabled = ccl_enabled
940939 self .input_shapes , self .output_names = None , None
941940
942941 @property
@@ -955,7 +954,7 @@ def model_name(self) -> str:
955954 return mname
956955
957956 @classmethod
958- def from_pretrained (cls , pretrained_model_name_or_path : str , qaic_config : Optional [ dict ] = None , ** kwargs ):
957+ def from_pretrained (cls , pretrained_model_name_or_path : str , ** kwargs ):
959958 """
960959 Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path.
961960
@@ -980,11 +979,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
980979 logger .warning ("Updating low_cpu_mem_usage=False" )
981980
982981 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
982+ ccl_enabled = kwargs .pop ("ccl_enabled" , None )
983+
983984 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , ** kwargs )
984985 return cls (
985986 model ,
986987 pretrained_model_name_or_path = pretrained_model_name_or_path ,
987- qaic_config = qaic_config ,
988+ ccl_enabled = ccl_enabled ,
988989 ** kwargs ,
989990 )
990991
@@ -1090,6 +1091,8 @@ def compile(
10901091 compile_dir : Optional [str ] = None ,
10911092 * ,
10921093 prefill_seq_len : Optional [int ] = None ,
1094+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
1095+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
10931096 ctx_len : Optional [int ] = None ,
10941097 batch_size : int = 1 ,
10951098 full_batch_size : Optional [int ] = None ,
@@ -1174,10 +1177,21 @@ def compile(
11741177
11751178 output_names = self .model .get_output_names (kv_offload = True )
11761179
1180+ # if ccl_enabled is True read Compute-Context-Length lists
1181+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = None , None
1182+ if self .ccl_enabled :
1183+ if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None :
1184+ logger .warning (
1185+ "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1186+ )
1187+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
1188+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
1189+ )
1190+
11771191 # For supporting VLLM and Disaggregated with CCL
1178- if " comp_ctx_lengths_prefill" in compiler_options :
1179- self .comp_ctx_lengths_prefill = compiler_options . pop ( " comp_ctx_lengths_prefill" )
1180- self .comp_ctx_lengths_decode = compiler_options . pop ( " comp_ctx_lengths_decode" )
1192+ if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None :
1193+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1194+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
11811195
11821196 specializations , compiler_options = self .model .get_specializations (
11831197 batch_size = batch_size ,
@@ -1600,7 +1614,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
16001614 def __init__ (
16011615 self ,
16021616 model : nn .Module ,
1603- qaic_config : Optional [ dict ] = None ,
1617+ ccl_enabled : bool = False ,
16041618 ** kwargs ,
16051619 ):
16061620 """
@@ -1622,8 +1636,6 @@ def __init__(
16221636 raise NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
16231637 super ().__init__ (model , ** kwargs )
16241638
1625- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (qaic_config )
1626-
16271639 # to handle internvl models
16281640 if hasattr (self .model .config , "llm_config" ) and hasattr (self .model .config , "vision_config" ):
16291641 self .model .config .llm_config .use_cache = True
@@ -1635,12 +1647,12 @@ def __init__(
16351647 else :
16361648 self .model .config .use_cache = True
16371649 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
1650+ self .ccl_enabled = ccl_enabled
16381651
16391652 @classmethod
16401653 def from_pretrained (
16411654 cls ,
16421655 pretrained_model_name_or_path ,
1643- qaic_config : Optional [dict ] = None ,
16441656 * args ,
16451657 ** kwargs ,
16461658 ):
@@ -1671,6 +1683,8 @@ def from_pretrained(
16711683 logger .warning ("Updating low_cpu_mem_usage=False" )
16721684
16731685 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
1686+ ccl_enabled = kwargs .pop ("ccl_enabled" , None )
1687+
16741688 from transformers import AutoConfig
16751689
16761690 config = AutoConfig .from_pretrained (pretrained_model_name_or_path , trust_remote_code = True )
@@ -1681,7 +1695,7 @@ def from_pretrained(
16811695 return cls (
16821696 model ,
16831697 pretrained_model_name_or_path = pretrained_model_name_or_path ,
1684- qaic_config = qaic_config ,
1698+ ccl_enabled = ccl_enabled ,
16851699 ** kwargs ,
16861700 )
16871701
@@ -1725,6 +1739,8 @@ def compile(
17251739 * ,
17261740 prefill_seq_len : Optional [int ] = None ,
17271741 ctx_len : Optional [int ] = None ,
1742+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
1743+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
17281744 batch_size : int = 1 ,
17291745 full_batch_size : Optional [int ] = None ,
17301746 kv_cache_batch_size : Optional [int ] = None ,
@@ -1794,10 +1810,21 @@ def compile(
17941810 kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
17951811 output_names = self .model .get_output_names ()
17961812
1813+ # if ccl_enabled is True read Compute-Context-Length lists
1814+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = None , None
1815+ if self .ccl_enabled :
1816+ if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None :
1817+ logger .warning (
1818+ "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1819+ )
1820+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
1821+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
1822+ )
1823+
17971824 # For supporting VLLM and Disaggregated with CCL
1798- if " comp_ctx_lengths_prefill" in compiler_options :
1799- self .comp_ctx_lengths_prefill = compiler_options . pop ( " comp_ctx_lengths_prefill" )
1800- self .comp_ctx_lengths_decode = compiler_options . pop ( " comp_ctx_lengths_decode" )
1825+ if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None :
1826+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1827+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
18011828
18021829 # Get specializations from modelling file
18031830 # TODO: expose this via the auto class as well
@@ -2180,7 +2207,7 @@ def __new__(
21802207 model : nn .Module ,
21812208 kv_offload : Optional [bool ] = True ,
21822209 continuous_batching : bool = False ,
2183- qaic_config : Optional [ dict ] = None ,
2210+ ccl_enabled : bool = False ,
21842211 ** kwargs ,
21852212 ):
21862213 """
@@ -2204,10 +2231,10 @@ def __new__(
22042231 """
22052232 if kv_offload :
22062233 return _QEffAutoModelForImageTextToTextDualQPC (
2207- model , continuous_batching , qaic_config = qaic_config , ** kwargs
2234+ model , continuous_batching , ccl_enabled = ccl_enabled , ** kwargs
22082235 )
22092236 else :
2210- return _QEFFAutoModelForImageTextToTextSingleQPC (model , qaic_config = qaic_config , ** kwargs )
2237+ return _QEFFAutoModelForImageTextToTextSingleQPC (model , ccl_enabled = ccl_enabled , ** kwargs )
22112238
22122239 @classmethod
22132240 @with_replaced_quantizers
@@ -2257,14 +2284,15 @@ def from_pretrained(
22572284 logger .warning ("Updating low_cpu_mem_usage=False" )
22582285
22592286 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
2287+ ccl_enabled = kwargs .pop ("ccl_enabled" , None )
22602288
22612289 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , ** kwargs )
22622290 return cls (
22632291 model ,
22642292 kv_offload = kv_offload ,
22652293 continuous_batching = continuous_batching ,
22662294 pretrained_model_name_or_path = pretrained_model_name_or_path ,
2267- qaic_config = qaic_config ,
2295+ ccl_enabled = ccl_enabled ,
22682296 ** kwargs ,
22692297 )
22702298
@@ -2317,6 +2345,7 @@ def __init__(
23172345 model : nn .Module ,
23182346 continuous_batching : bool = False ,
23192347 qaic_config : Optional [dict ] = None ,
2348+ ccl_enabled : bool = False ,
23202349 ** kwargs ,
23212350 ):
23222351 """
@@ -2363,8 +2392,6 @@ def __init__(
23632392 # Set use_cache=True to get KV values as output during ONNX export
23642393 model .config .use_cache = True
23652394
2366- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (qaic_config )
2367-
23682395 super ().__init__ (model , qaic_config = qaic_config , ** kwargs )
23692396 self .num_layers = model .config .num_hidden_layers
23702397 self .continuous_batching = continuous_batching
@@ -2373,6 +2400,7 @@ def __init__(
23732400 self .is_tlm = transformed
23742401
23752402 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
2403+ self .ccl_enabled = ccl_enabled
23762404
23772405 # ---Sampling---
23782406 # Note: SamplerTransform should be applied after all other transforms
@@ -2465,6 +2493,7 @@ def from_pretrained(
24652493 logger .warning ("Updating low_cpu_mem_usage=False" )
24662494
24672495 kv_offload = kwargs .pop ("kv_offload" , None )
2496+ ccl_enabled = kwargs .pop ("ccl_enabled" , None )
24682497
24692498 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
24702499 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , * args , ** kwargs )
@@ -2478,14 +2507,15 @@ def from_pretrained(
24782507 model ,
24792508 kv_offload = kv_offload ,
24802509 pretrained_model_name_or_path = pretrained_model_name_or_path ,
2481- qaic_config = qaic_config ,
2510+ ccl_enabled = ccl_enabled ,
24822511 ** kwargs ,
24832512 )
24842513 return cls (
24852514 model ,
24862515 continuous_batching = continuous_batching ,
24872516 qaic_config = qaic_config ,
24882517 pretrained_model_name_or_path = pretrained_model_name_or_path ,
2518+ ccl_enabled = ccl_enabled ,
24892519 ** kwargs ,
24902520 )
24912521
@@ -2814,6 +2844,8 @@ def compile(
28142844 * ,
28152845 prefill_seq_len : int = 32 ,
28162846 ctx_len : int = 128 ,
2847+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
2848+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
28172849 batch_size : int = 1 ,
28182850 full_batch_size : Optional [int ] = None ,
28192851 kv_cache_batch_size : Optional [int ] = None ,
@@ -2905,10 +2937,19 @@ def compile(
29052937
29062938 """
29072939
2940+ # if ccl_enabled is True read Compute-Context-Length lists
2941+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = None , None
2942+ if self .ccl_enabled :
2943+ if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None :
2944+ logger .warning (
2945+ "Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
2946+ )
2947+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
2948+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len
2949+ )
2950+
29082951 # For supporting VLLM and Disaggregated with CCL
2909- if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options :
2910- comp_ctx_lengths_prefill = compiler_options .pop ("comp_ctx_lengths_prefill" )
2911- comp_ctx_lengths_decode = compiler_options .pop ("comp_ctx_lengths_decode" )
2952+ if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None :
29122953 if isinstance (comp_ctx_lengths_prefill , str ):
29132954 import ast
29142955
0 commit comments