def handle_generation_config(model):
    """Turn on sampling when the generation config carries sampling settings.

    If ``model.generation_config`` exists and any of ``top_p``, ``top_k`` or
    ``temperature`` is explicitly set to something other than its neutral
    value (``1.0``, ``0`` and ``1.0`` respectively), ``do_sample`` is set to
    ``True`` on that config. Attributes that are missing or ``None`` are
    treated as "not configured" and never flip the flag.

    Args:
        model: Any object; only its optional ``generation_config`` attribute
            is read, and that config object is mutated in place.

    Returns:
        None.
    """
    generation_config = getattr(model, "generation_config", None)
    if generation_config is None:
        return

    # Value at which each sampling knob is considered a no-op.
    # NOTE(review): the neutral value for top_k is kept at 0 as originally
    # written; confirm it matches the default the generation library uses.
    neutral_values = {"top_p": 1.0, "top_k": 0, "temperature": 1.0}

    for name, neutral in neutral_values.items():
        value = getattr(generation_config, name, None)
        # `None != neutral` is always true, so an unset knob must be skipped
        # explicitly or it would force sampling on.
        if value is not None and value != neutral:
            generation_config.do_sample = True
            return