diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py index 4cf0db4..75a95ba 100644 --- a/src/post_training/methods/common.py +++ b/src/post_training/methods/common.py @@ -126,6 +126,54 @@ def build_common_training_kwargs( ) +def sanitize_generation_config(trainer: Any) -> None: + """Fix inconsistent ``generation_config`` so checkpoint saves don't fail. + + Some upstream models (notably Olmo-3 Think variants) ship a + ``generation_config.json`` that sets sampling-only parameters + (``temperature``, ``top_p``) while leaving ``do_sample=False``. This is + benign at training time — we never call ``model.generate`` — but + ``transformers >= 5.x`` runs strict validation inside + ``GenerationConfig.save_pretrained`` and refuses to write the file:: + + ValueError: GenerationConfig is invalid: + - `temperature` is set to 0.6 -- this flag is only used in + sample-based generation modes. You should set `do_sample=True` + or unset `temperature`. + + Every checkpoint save ultimately calls ``model.save_pretrained`` which + writes the generation config, so an unfixed model crashes the very first + save. We patch ``do_sample`` to ``True`` once, on the in-memory model + object, immediately after the trainer has been constructed. The fix is + local to this run — the upstream model files on the Hub are unchanged. + """ + model = getattr(trainer, "model", None) + if model is None: + return + gc = getattr(model, "generation_config", None) + if gc is None: + return + _FLOAT_SAMPLING_PARAMS = ( + "temperature", + "top_p", + "min_p", + "top_h", + "typical_p", + "epsilon_cutoff", + "eta_cutoff", + ) + has_sampling_param = any( + getattr(gc, p, None) is not None for p in _FLOAT_SAMPLING_PARAMS + ) or getattr(gc, "top_k", None) not in (None, 0) + if has_sampling_param and not getattr(gc, "do_sample", False): + logger.info( + "Sanitizing generation_config: setting do_sample=True so that " + "checkpoint saves can write generation_config.json without " + "tripping transformers' strict validation." + ) + gc.do_sample = True + + def build_callbacks(config: PostTrainingConfig, run_dir: Path) -> list: """Build the callback list (shared across methods).""" callbacks: list = [] diff --git a/src/post_training/methods/dpo.py b/src/post_training/methods/dpo.py index bcdb125..509064f 100644 --- a/src/post_training/methods/dpo.py +++ b/src/post_training/methods/dpo.py @@ -14,6 +14,7 @@ build_common_training_kwargs, build_model_init_kwargs, build_tokenizer, + sanitize_generation_config, ) if TYPE_CHECKING: @@ -56,7 +57,7 @@ def build_dpo_trainer(config: PostTrainingConfig, run_dir: Path) -> DPOTrainer: model_init_kwargs=build_model_init_kwargs(config), ) - return DPOTrainer( + trainer = DPOTrainer( model=config.model.name_or_path, ref_model=mc.ref_model_name_or_path, # None → TRL creates implicit copy processing_class=tokenizer, @@ -64,3 +65,5 @@ def build_dpo_trainer(config: PostTrainingConfig, run_dir: Path) -> DPOTrainer: args=dpo_config, callbacks=build_callbacks(config, run_dir), ) + sanitize_generation_config(trainer) + return trainer diff --git a/src/post_training/methods/sft.py b/src/post_training/methods/sft.py index a34aa68..7f57cbe 100644 --- a/src/post_training/methods/sft.py +++ b/src/post_training/methods/sft.py @@ -14,6 +14,7 @@ build_common_training_kwargs, build_model_init_kwargs, build_tokenizer, + sanitize_generation_config, ) if TYPE_CHECKING: @@ -55,10 +56,13 @@ def build_sft_trainer(config: PostTrainingConfig, run_dir: Path) -> SFTTrainer: model_init_kwargs=build_model_init_kwargs(config), ) - return SFTTrainer( + trainer = SFTTrainer( model=config.model.name_or_path, processing_class=tokenizer, train_dataset=dataset, args=sft_config, callbacks=build_callbacks(config, run_dir), ) + + sanitize_generation_config(trainer) + return trainer