@@ -92,6 +92,7 @@ def __init__(
9292 is_tlm : bool = False ,
9393 include_sampler : bool = False ,
9494 return_pdfs : bool = False ,
95+ include_guided_decoding : bool = False ,
9596 sampling_params : Optional [Dict [str , Any ]] = None ,
9697 ):
9798 """
@@ -111,6 +112,7 @@ def __init__(
111112 is_tlm: Target language model flag
112113 include_sampler: Enable on-device sampling (new feature)
113114 return_pdfs: Return probability distributions
115+ include_guided_decoding: Enable guided decoding in on-device sampling
114116 sampling_params: Sampling parameters for on-device sampling
115117 """
116118 # Validate required parameters
@@ -134,6 +136,7 @@ def __init__(
134136 is_tlm = is_tlm ,
135137 include_sampler = include_sampler ,
136138 return_pdfs = return_pdfs ,
139+ include_guided_decoding = include_guided_decoding ,
137140 sampling_params = sampling_params ,
138141 activate = False , # vision components need to be initialized first
139142 )
@@ -305,7 +308,7 @@ def _execute_chunked_prefill(
305308 lang_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
306309
307310 if self .include_sampler :
308- for op in Constants .SAMPLER_OPS :
311+ for op in Constants .SAMPLER_OPS | ({ "token_bitmasks" } if self . include_guided_decoding else set ()) :
309312 if decode_batch_id is not None :
310313 lang_inputs [op ] = self .sampling_params [op ][decode_batch_id .flatten ()]
311314 else :
@@ -338,7 +341,7 @@ def _execute_chunked_prefill(
338341
339342 if self .include_sampler :
340343 chunk_inputs ["last_accepted_output_tokens" ] = chunk_inputs ["input_ids" ]
341- for op in Constants .SAMPLER_OPS :
344+ for op in Constants .SAMPLER_OPS | ({ "token_bitmasks" } if self . include_guided_decoding else set ()) :
342345 chunk_inputs [op ] = lang_inputs [op ]
343346
344347 outputs = self ._session .run (chunk_inputs )
@@ -793,6 +796,7 @@ def generate_stream_tokens(
793796 is_tlm = self .is_tlm ,
794797 include_sampler = self .include_sampler ,
795798 return_pdfs = self .return_pdfs ,
799+ include_guided_decoding = self .include_guided_decoding ,
796800 sampling_params = self .sampling_params ,
797801 )
798802
0 commit comments