diff --git a/dflash/model.py b/dflash/model.py index ef8d497..d64e838 100644 --- a/dflash/model.py +++ b/dflash/model.py @@ -118,7 +118,7 @@ def dflash_generate( is_causal=False, )[:, 1 - block_size :, :]) past_key_values_draft.crop(start) - block_output_ids[:, 1:] = sample(draft_logits) + block_output_ids[:, 1:] = sample(draft_logits, temperature) if draft_prefill and return_stats: draft_prefill = False decode_start = _cuda_time()