Skip to content

Commit 45aed11

Browse files
committed
Add unit tests
Signed-off-by: quic-xiyushi <xiyushi@qti.qualcomm.com>
1 parent 7cf106e commit 45aed11

File tree

3 files changed

+142
-31
lines changed

3 files changed

+142
-31
lines changed

QEfficient/generation/vlm_generation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
     write_io_files,
 )
 from QEfficient.utils import LRUCache
+from QEfficient.utils.constants import Constants
 from QEfficient.utils.logging_utils import logger
 
 
@@ -303,6 +304,13 @@ def _execute_chunked_prefill(
         prefill_ccl_id = 0
         lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
 
+        if self.include_sampler:
+            for op in Constants.SAMPLER_OPS:
+                if decode_batch_id is not None:
+                    lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
+                else:
+                    lang_inputs[op] = self.sampling_params[op]
+
         for i in range(num_chunks):
             input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
             position_ids_slice = lang_inputs["position_ids"][
@@ -328,6 +336,11 @@ def _execute_chunked_prefill(
 
             chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]
 
+            if self.include_sampler:
+                chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
+                for op in Constants.SAMPLER_OPS:
+                    chunk_inputs[op] = lang_inputs[op]
+
             outputs = self._session.run(chunk_inputs)
 
             if "image_idx_output" in outputs:

QEfficient/transformers/models/modeling_auto.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,10 @@ def __init__(
             If `full_batch_size` is provided.
         """
         if kwargs.pop("full_batch_size", None):
-            raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+            continuous_batching = True
+            warnings.warn(
+                "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+            )
         self.model = model
         self.config = model.config
 
@@ -1028,7 +1031,7 @@ def export(
             output_names=output_names["lang"],
             dynamic_axes=dynamic_axes["lang"],
             continuous_batching=self.continuous_batching,
-            vocab_size=self.lang_model.model.config.vocab_size,
+            vocab_size=self.config.vocab_size,
             qaic_config=self.lang_model.model.qaic_config,
         )
 

@@ -1235,6 +1238,7 @@ def generate(
         device_ids: List[int] = None,
         runtime_ai100: bool = True,
         generation_len: Optional[int] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, np.ndarray]:
         """
         Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards.
@@ -1293,6 +1297,7 @@
             full_batch_size=fbs,
             comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill,
             comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
+            **kwargs,
         )
 
         # Call generate method
@@ -1572,11 +1577,16 @@ def __init__(
         Raises
         ------
         NotImplementedError
-            If `full_batch_size` is provided.
+            If `full_batch_size` is provided or `continuous_batching` is True or `include_sampler` is True.
         """
         if kwargs.pop("full_batch_size", None):
+            warnings.warn(
+                "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+            )
+            raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        if kwargs.pop("continuous_batching", None):
             raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
-        if kwargs.pop("qaic_config", None):
+        if qaic_config is not None and qaic_config.pop("include_sampler", False):
             raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
         super().__init__(model, **kwargs)
 
0 commit comments

Comments (0)