Skip to content

Commit a24a55d

Browse files
Author: sanising (committed)
Enable guided decoding in vlm generation
Signed-off-by: sanising <sanising@qti.qualcomm.com>
1 parent 251099f commit a24a55d

File tree

1 file changed: +6 additions, -2 deletions

QEfficient/generation/vlm_generation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
is_tlm: bool = False,
9393
include_sampler: bool = False,
9494
return_pdfs: bool = False,
95+
include_guided_decoding: bool = False,
9596
sampling_params: Optional[Dict[str, Any]] = None,
9697
):
9798
"""
@@ -111,6 +112,7 @@ def __init__(
111112
is_tlm: Target language model flag
112113
include_sampler: Enable on-device sampling (new feature)
113114
return_pdfs: Return probability distributions
115+
include_guided_decoding: Enable guided decoding in on-device sampling
114116
sampling_params: Sampling parameters for on-device sampling
115117
"""
116118
# Validate required parameters
@@ -134,6 +136,7 @@ def __init__(
134136
is_tlm=is_tlm,
135137
include_sampler=include_sampler,
136138
return_pdfs=return_pdfs,
139+
include_guided_decoding=include_guided_decoding,
137140
sampling_params=sampling_params,
138141
activate=False, # vision components need to be initialized first
139142
)
@@ -305,7 +308,7 @@ def _execute_chunked_prefill(
305308
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
306309

307310
if self.include_sampler:
308-
for op in Constants.SAMPLER_OPS:
311+
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
309312
if decode_batch_id is not None:
310313
lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
311314
else:
@@ -338,7 +341,7 @@ def _execute_chunked_prefill(
338341

339342
if self.include_sampler:
340343
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
341-
for op in Constants.SAMPLER_OPS:
344+
for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()):
342345
chunk_inputs[op] = lang_inputs[op]
343346

344347
outputs = self._session.run(chunk_inputs)
@@ -793,6 +796,7 @@ def generate_stream_tokens(
793796
is_tlm=self.is_tlm,
794797
include_sampler=self.include_sampler,
795798
return_pdfs=self.return_pdfs,
799+
include_guided_decoding=self.include_guided_decoding,
796800
sampling_params=self.sampling_params,
797801
)
798802

0 commit comments

Comments (0)