Add handle_generation_config function to manage model generation_config saving failure #1448
Conversation
Pull request overview
Adds a helper to adjust generation_config.do_sample based on sampling-related parameters to avoid failures when working with non-default generation settings.
Changes:
- Invoke a new `handle_generation_config()` during both the LLM and MLLM model load flows.
- Add `handle_generation_config()`, which sets `do_sample=True` when `top_p`, `top_k`, or `temperature` indicates sampling.
```python
model = model.eval()
check_and_mark_quantized_module(model)
handle_generation_config(model)
```
This call mutates model.generation_config during load, which can change downstream generation behavior (enabling sampling) even if the caller did not intend behavior changes at load time. Since the PR goal is to address generation_config saving failures, consider moving this normalization to the save/export path (or applying it only immediately before serialization) to avoid surprising side effects during loading.
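As a hedged illustration of the save-path alternative suggested above: the normalization could run only inside the save routine, leaving load-time behavior untouched. The `save_quantized` wrapper below is hypothetical; only the `handle_generation_config` name comes from this PR.

```python
def handle_generation_config(model):
    # Normalize sampling-related fields so GenerationConfig validation
    # does not fail on save (sketch of the PR's helper).
    generation_config = getattr(model, "generation_config", None)
    if generation_config is None:
        return
    if (
        getattr(generation_config, "top_p", 1.0) != 1.0
        or getattr(generation_config, "top_k", 0) != 0
        or getattr(generation_config, "temperature", 1.0) != 1.0
    ):
        generation_config.do_sample = True


def save_quantized(model, output_dir):
    # Hypothetical save path: apply the normalization immediately before
    # serialization rather than during model load.
    handle_generation_config(model)
    model.save_pretrained(output_dir)
```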
```python
model = model.eval()
check_and_mark_quantized_module(model)
handle_generation_config(model)
```
Same concern as in llm_load_model: mutating generation settings during model load can unexpectedly change runtime generation behavior. If the intent is specifically to avoid GenerationConfig validation errors on save, prefer applying this right before saving rather than at load time.
```python
if hasattr(generation_config, "top_k") and generation_config.top_k != 0:
    model.generation_config.do_sample = True
if hasattr(generation_config, "temperature") and generation_config.temperature != 1.0:
    model.generation_config.do_sample = True
```
If the intent is to prevent GenerationConfig validation/saving failures caused by inconsistent sampling settings, this handling looks incomplete: Transformers' validation can also consider other sampling-related fields (e.g., typical_p, min_p, epsilon_cutoff, eta_cutoff, etc.). With the current implementation, save/validate can still fail when those are set away from defaults while do_sample remains False. Consider expanding the normalization condition to cover all sampling parameters that require do_sample=True.
Suggested change:

```python
model.generation_config.do_sample = True
# Additional sampling-related parameters that also imply do_sample=True
if hasattr(generation_config, "typical_p") and generation_config.typical_p is not None:
    model.generation_config.do_sample = True
if hasattr(generation_config, "min_p") and generation_config.min_p is not None:
    model.generation_config.do_sample = True
if hasattr(generation_config, "epsilon_cutoff") and generation_config.epsilon_cutoff is not None:
    model.generation_config.do_sample = True
if hasattr(generation_config, "eta_cutoff") and generation_config.eta_cutoff is not None:
    model.generation_config.do_sample = True
```
```python
if hasattr(generation_config, "top_p") and generation_config.top_p != 1.0:
    model.generation_config.do_sample = True
if hasattr(generation_config, "top_k") and generation_config.top_k != 0:
    model.generation_config.do_sample = True
if hasattr(generation_config, "temperature") and generation_config.temperature != 1.0:
    model.generation_config.do_sample = True
```
This repeats both the attribute checks and the assignment to model.generation_config.do_sample. Since generation_config is already a local variable, it would be clearer to set generation_config.do_sample once based on a combined condition (e.g., compute a boolean like needs_sampling = ... and then assign once). This reduces duplication and makes it easier to extend the list of parameters consistently.
Suggested change:

```python
needs_sampling = (
    (hasattr(generation_config, "top_p") and generation_config.top_p != 1.0)
    or (hasattr(generation_config, "top_k") and generation_config.top_k != 0)
    or (
        hasattr(generation_config, "temperature")
        and generation_config.temperature != 1.0
    )
)
if needs_sampling:
    generation_config.do_sample = True
```
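Putting the two suggestions together (the extended parameter list and the single combined condition), the helper could be sketched as below. The `(name, default)` pairs are assumptions based on Transformers' usual GenerationConfig defaults, not the PR's exact code:

```python
# (field name, neutral default): a value differing from the default is
# taken to mean the user intends sampling. The list is an assumption.
_SAMPLING_PARAMS = (
    ("top_p", 1.0),
    ("top_k", 0),
    ("temperature", 1.0),
    ("typical_p", 1.0),
    ("min_p", None),
    ("epsilon_cutoff", 0.0),
    ("eta_cutoff", 0.0),
)


def handle_generation_config(model):
    generation_config = getattr(model, "generation_config", None)
    if generation_config is None:
        return
    # Treat both the neutral default and None as "not set".
    needs_sampling = any(
        getattr(generation_config, name, default) not in (default, None)
        for name, default in _SAMPLING_PARAMS
    )
    if needs_sampling:
        generation_config.do_sample = True
```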
```python
model = model.eval()
check_and_mark_quantized_module(model)
handle_generation_config(model)
```
Please add a TODO with a link to huggingface/transformers#43937, so the workaround can be removed once Transformers ships a fix.
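For example, the requested marker could sit at the top of the helper (sketch only; the body is elided):

```python
def handle_generation_config(model):
    # TODO: workaround for GenerationConfig saving failures; remove once
    # huggingface/transformers#43937 is fixed upstream.
    generation_config = getattr(model, "generation_config", None)
    if generation_config is None:
        return
    ...
```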
Description
huggingface/transformers#43937
Type of Change
Related Issues
Fixes or relates to #
Checklist Before Submitting