@@ -881,7 +881,10 @@ def __init__(
881881 If `full_batch_size` is provided.
882882 """
883883 if kwargs .pop ("full_batch_size" , None ):
884- raise NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
884+ continuous_batching = True
885+ warnings .warn (
886+ "full_batch_size argument is deprecated. Use continuous_batching=True instead." , DeprecationWarning , 2
887+ )
885888 self .model = model
886889 self .config = model .config
887890
@@ -1028,7 +1031,7 @@ def export(
10281031 output_names = output_names ["lang" ],
10291032 dynamic_axes = dynamic_axes ["lang" ],
10301033 continuous_batching = self .continuous_batching ,
1031- vocab_size = self .lang_model . model . config .vocab_size ,
1034+ vocab_size = self .config .vocab_size ,
10321035 qaic_config = self .lang_model .model .qaic_config ,
10331036 )
10341037
@@ -1235,6 +1238,7 @@ def generate(
12351238 device_ids : List [int ] = None ,
12361239 runtime_ai100 : bool = True ,
12371240 generation_len : Optional [int ] = None ,
1241+ ** kwargs ,
12381242 ) -> Union [torch .Tensor , np .ndarray ]:
12391243 """
12401244 Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards.
@@ -1293,6 +1297,7 @@ def generate(
12931297 full_batch_size = fbs ,
12941298 comp_ctx_lengths_prefill = self .comp_ctx_lengths_prefill ,
12951299 comp_ctx_lengths_decode = self .comp_ctx_lengths_decode ,
1300+ ** kwargs ,
12961301 )
12971302
12981303 # Call generate method
@@ -1572,11 +1577,16 @@ def __init__(
15721577 Raises
15731578 ------
15741579 NotImplementedError
1575- If `full_batch_size` is provided.
1580+ If `full_batch_size` is provided or `continuous_batching` is True or `include_sampler` is True .
15761581 """
15771582 if kwargs .pop ("full_batch_size" , None ):
1583+ warnings .warn (
1584+ "full_batch_size argument is deprecated. Use continuous_batching=True instead." , DeprecationWarning , 2
1585+ )
1586+ raise NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
1587+ if kwargs .pop ("continuous_batching" , None ):
15781588 raise NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
1579- if kwargs .pop ("qaic_config " , None ):
1589+ if qaic_config is not None and qaic_config .pop ("include_sampler " , False ):
15801590 raise NotImplementedError ("On-device sampling is not supported for single QPC multimodal models yet." )
15811591 super ().__init__ (model , ** kwargs )
15821592
0 commit comments