From a627af887b9b0df82fe3f925ab8a897e12242a5a Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Tue, 28 Apr 2026 22:07:26 +0200 Subject: [PATCH 1/3] fix: sanitize generation config to prevent save_pretrained crash on OLMo-3 Think OLMo-3 Think models ship temperature/top_p with do_sample=False. transformers >= 5.x strict validation rejects this in GenerationConfig.save_pretrained, crashing every checkpoint save. Set do_sample=True in-memory on the trainer's model after construction. The upstream Hub files are unmodified; saved checkpoints preserve the model's recommended inference settings. --- src/post_training/methods/sft.py | 53 +++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/src/post_training/methods/sft.py b/src/post_training/methods/sft.py index a34aa68..14c7290 100644 --- a/src/post_training/methods/sft.py +++ b/src/post_training/methods/sft.py @@ -55,10 +55,61 @@ def build_sft_trainer(config: PostTrainingConfig, run_dir: Path) -> SFTTrainer: model_init_kwargs=build_model_init_kwargs(config), ) - return SFTTrainer( + trainer = SFTTrainer( model=config.model.name_or_path, processing_class=tokenizer, train_dataset=dataset, args=sft_config, callbacks=build_callbacks(config, run_dir), ) + + _sanitize_generation_config(trainer) + return trainer + + +def _sanitize_generation_config(trainer: SFTTrainer) -> None: + """Fix inconsistent ``generation_config`` so checkpoint saves don't fail. + + Some upstream models (notably Olmo-3 Think variants) ship a + ``generation_config.json`` that sets sampling-only parameters + (``temperature``, ``top_p``) while leaving ``do_sample=False``. This is + benign at training time — we never call ``model.generate`` — but + ``transformers >= 5.x`` runs strict validation inside + ``GenerationConfig.save_pretrained`` and refuses to write the file:: + + ValueError: GenerationConfig is invalid: + - `temperature` is set to 0.6 -- this flag is only used in + sample-based generation modes. You should set `do_sample=True` + or unset `temperature`. + + Every checkpoint save (HF Trainer's ``_save`` and our + ``InferenceCheckpointCallback``) ultimately calls ``model.save_pretrained`` + which writes the generation config, so an unfixed model crashes the + very first save. We patch ``do_sample`` to ``True`` once, on the + in-memory model object, immediately after the trainer (and therefore + the model) has been constructed. The fix is local to this run — the + upstream model files on the Hub are unchanged. + + AllenAI's open-instruct solves the same issue in ``model_utils.py`` + by setting ``temperature=None, top_p=None`` (stripping the params). + We instead set ``do_sample=True`` to preserve the upstream model's + recommended inference settings in the saved checkpoint. + """ + model = getattr(trainer, "model", None) + if model is None: + return + gc = getattr(model, "generation_config", None) + if gc is None: + return + has_sampling_param = ( + getattr(gc, "temperature", None) is not None + or getattr(gc, "top_p", None) is not None + or getattr(gc, "top_k", None) not in (None, 0) + ) + if has_sampling_param and not getattr(gc, "do_sample", False): + logger.info( + "Sanitizing generation_config: setting do_sample=True so that " + "checkpoint saves can write generation_config.json without " + "tripping transformers' strict validation." + ) + gc.do_sample = True From beddbdfd0a3d305ffdca29d8b5abf981bdfed426 Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Thu, 30 Apr 2026 16:16:53 +0200 Subject: [PATCH 2/3] fix: move sanitize_generation_config to common and apply to DPO --- src/post_training/methods/common.py | 41 +++++++++++++++++++++++ src/post_training/methods/dpo.py | 5 ++- src/post_training/methods/sft.py | 51 ++--------------------------- 3 files changed, 47 insertions(+), 50 deletions(-) diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py index 4cf0db4..7fa2fb6 100644 --- a/src/post_training/methods/common.py +++ b/src/post_training/methods/common.py @@ -126,6 +126,47 @@ def build_common_training_kwargs( ) +def sanitize_generation_config(trainer: Any) -> None: + """Fix inconsistent ``generation_config`` so checkpoint saves don't fail. + + Some upstream models (notably Olmo-3 Think variants) ship a + ``generation_config.json`` that sets sampling-only parameters + (``temperature``, ``top_p``) while leaving ``do_sample=False``. This is + benign at training time — we never call ``model.generate`` — but + ``transformers >= 5.x`` runs strict validation inside + ``GenerationConfig.save_pretrained`` and refuses to write the file:: + + ValueError: GenerationConfig is invalid: + - `temperature` is set to 0.6 -- this flag is only used in + sample-based generation modes. You should set `do_sample=True` + or unset `temperature`. + + Every checkpoint save ultimately calls ``model.save_pretrained`` which + writes the generation config, so an unfixed model crashes the very first + save. We patch ``do_sample`` to ``True`` once, on the in-memory model + object, immediately after the trainer has been constructed. The fix is + local to this run — the upstream model files on the Hub are unchanged. + """ + model = getattr(trainer, "model", None) + if model is None: + return + gc = getattr(model, "generation_config", None) + if gc is None: + return + has_sampling_param = ( + getattr(gc, "temperature", None) is not None + or getattr(gc, "top_p", None) is not None + or getattr(gc, "top_k", None) not in (None, 0) + ) + if has_sampling_param and not getattr(gc, "do_sample", False): + logger.info( + "Sanitizing generation_config: setting do_sample=True so that " + "checkpoint saves can write generation_config.json without " + "tripping transformers' strict validation." + ) + gc.do_sample = True + + def build_callbacks(config: PostTrainingConfig, run_dir: Path) -> list: """Build the callback list (shared across methods).""" callbacks: list = [] diff --git a/src/post_training/methods/dpo.py b/src/post_training/methods/dpo.py index bcdb125..509064f 100644 --- a/src/post_training/methods/dpo.py +++ b/src/post_training/methods/dpo.py @@ -14,6 +14,7 @@ build_common_training_kwargs, build_model_init_kwargs, build_tokenizer, + sanitize_generation_config, ) if TYPE_CHECKING: @@ -56,7 +57,7 @@ def build_dpo_trainer(config: PostTrainingConfig, run_dir: Path) -> DPOTrainer: model_init_kwargs=build_model_init_kwargs(config), ) - return DPOTrainer( + trainer = DPOTrainer( model=config.model.name_or_path, ref_model=mc.ref_model_name_or_path, # None → TRL creates implicit copy processing_class=tokenizer, @@ -64,3 +65,5 @@ def build_dpo_trainer(config: PostTrainingConfig, run_dir: Path) -> DPOTrainer: args=dpo_config, callbacks=build_callbacks(config, run_dir), ) + sanitize_generation_config(trainer) + return trainer diff --git a/src/post_training/methods/sft.py b/src/post_training/methods/sft.py index 14c7290..7f57cbe 100644 --- a/src/post_training/methods/sft.py +++ b/src/post_training/methods/sft.py @@ -14,6 +14,7 @@ build_common_training_kwargs, build_model_init_kwargs, build_tokenizer, + sanitize_generation_config, ) if TYPE_CHECKING: @@ -63,53 +64,5 @@ def build_sft_trainer(config: PostTrainingConfig, run_dir: Path) -> SFTTrainer: callbacks=build_callbacks(config, run_dir), ) - _sanitize_generation_config(trainer) + sanitize_generation_config(trainer) return trainer - - -def _sanitize_generation_config(trainer: SFTTrainer) -> None: - """Fix inconsistent ``generation_config`` so checkpoint saves don't fail. - - Some upstream models (notably Olmo-3 Think variants) ship a - ``generation_config.json`` that sets sampling-only parameters - (``temperature``, ``top_p``) while leaving ``do_sample=False``. This is - benign at training time — we never call ``model.generate`` — but - ``transformers >= 5.x`` runs strict validation inside - ``GenerationConfig.save_pretrained`` and refuses to write the file:: - - ValueError: GenerationConfig is invalid: - - `temperature` is set to 0.6 -- this flag is only used in - sample-based generation modes. You should set `do_sample=True` - or unset `temperature`. - - Every checkpoint save (HF Trainer's ``_save`` and our - ``InferenceCheckpointCallback``) ultimately calls ``model.save_pretrained`` - which writes the generation config, so an unfixed model crashes the - very first save. We patch ``do_sample`` to ``True`` once, on the - in-memory model object, immediately after the trainer (and therefore - the model) has been constructed. The fix is local to this run — the - upstream model files on the Hub are unchanged. - - AllenAI's open-instruct solves the same issue in ``model_utils.py`` - by setting ``temperature=None, top_p=None`` (stripping the params). - We instead set ``do_sample=True`` to preserve the upstream model's - recommended inference settings in the saved checkpoint. - """ - model = getattr(trainer, "model", None) - if model is None: - return - gc = getattr(model, "generation_config", None) - if gc is None: - return - has_sampling_param = ( - getattr(gc, "temperature", None) is not None - or getattr(gc, "top_p", None) is not None - or getattr(gc, "top_k", None) not in (None, 0) - ) - if has_sampling_param and not getattr(gc, "do_sample", False): - logger.info( - "Sanitizing generation_config: setting do_sample=True so that " - "checkpoint saves can write generation_config.json without " - "tripping transformers' strict validation." - ) - gc.do_sample = True From 8c161b0f59ae0a59aef5aaddf038d26f888729a0 Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Thu, 30 Apr 2026 17:31:31 +0200 Subject: [PATCH 3/3] fix: extend sampling-param heuristic to all 8 HF validator params --- src/post_training/methods/common.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/post_training/methods/common.py b/src/post_training/methods/common.py index 7fa2fb6..75a95ba 100644 --- a/src/post_training/methods/common.py +++ b/src/post_training/methods/common.py @@ -153,11 +153,18 @@ def sanitize_generation_config(trainer: Any) -> None: gc = getattr(model, "generation_config", None) if gc is None: return - has_sampling_param = ( - getattr(gc, "temperature", None) is not None - or getattr(gc, "top_p", None) is not None - or getattr(gc, "top_k", None) not in (None, 0) + _FLOAT_SAMPLING_PARAMS = ( + "temperature", + "top_p", + "min_p", + "top_h", + "typical_p", + "epsilon_cutoff", + "eta_cutoff", ) + has_sampling_param = any( + getattr(gc, p, None) is not None for p in _FLOAT_SAMPLING_PARAMS + ) or getattr(gc, "top_k", None) not in (None, 0) if has_sampling_param and not getattr(gc, "do_sample", False): logger.info( "Sanitizing generation_config: setting do_sample=True so that "