From 082cccb1b1cc8bb9d3b666cee8429f20c19bc283 Mon Sep 17 00:00:00 2001
From: Xin He
Date: Thu, 12 Feb 2026 08:51:42 +0000
Subject: [PATCH] Add handle_generation_config function to manage model
 generation_config saving failure

---
 auto_round/utils/model.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py
index 2536c0a93..e02d8c7cb 100644
--- a/auto_round/utils/model.py
+++ b/auto_round/utils/model.py
@@ -328,6 +328,7 @@ def llm_load_model(
 
     model = model.eval()
     check_and_mark_quantized_module(model)
+    handle_generation_config(model)
     model = _to_model_dtype(model, model_dtype)
     return model, tokenizer
 
@@ -477,6 +478,7 @@ def mllm_load_model(
 
     model = model.eval()
     check_and_mark_quantized_module(model)
+    handle_generation_config(model)
     model = _to_model_dtype(model, model_dtype)
     return model, processor, tokenizer, image_processor
 
@@ -1549,3 +1551,23 @@ def is_separate_tensor(model: torch.nn.Module, tensor_name: str) -> bool:
         return True
     else:
         return False
+
+
+def handle_generation_config(model):
+    """Turn on `do_sample` when sampling parameters deviate from the
+    transformers defaults, so that saving the generation config does not
+    fail strict validation in `save_pretrained`."""
+    generation_config = getattr(model, "generation_config", None)
+    if generation_config is None:
+        return
+    # transformers defaults are top_p=1.0, top_k=50 and temperature=1.0;
+    # a value of None means "unset" and must not enable sampling.
+    top_p = getattr(generation_config, "top_p", None)
+    if top_p is not None and top_p != 1.0:
+        generation_config.do_sample = True
+    top_k = getattr(generation_config, "top_k", None)
+    if top_k is not None and top_k != 50:
+        generation_config.do_sample = True
+    temperature = getattr(generation_config, "temperature", None)
+    if temperature is not None and temperature != 1.0:
+        generation_config.do_sample = True
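
Note: the failure this patch works around comes from GenerationConfig.save_pretrained(),
which in recent transformers releases re-runs validate() and raises a ValueError when
do_sample=False is combined with non-default sampling parameters. The snippet below is
a minimal sketch of that failure mode and of the fix, assuming this recent transformers
behavior; the parameter values and the /tmp path are illustrative only.

    from transformers import GenerationConfig

    # A config like those shipped with some chat models: greedy decoding
    # (do_sample=False) combined with non-default sampling parameters.
    config = GenerationConfig(do_sample=False, temperature=0.6, top_p=0.9)

    try:
        # Raises ValueError on recent transformers: validation warns about
        # temperature/top_p being set while do_sample is False.
        config.save_pretrained("/tmp/gen_config")
    except ValueError as e:
        print("save failed:", e)

    # What handle_generation_config() does: since temperature and top_p
    # deviate from the defaults (1.0), enable sampling so validation passes.
    config.do_sample = True
    config.save_pretrained("/tmp/gen_config")  # now succeeds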