From 024c8bd5e8c63a5bab6df6a0727fd25e9c09506e Mon Sep 17 00:00:00 2001 From: Arjun Krishnakumar Date: Tue, 28 Apr 2026 22:06:51 +0200 Subject: [PATCH] feat: wrap dataset loading in main_process_first to prevent Lustre cache races On multi-node runs, all ranks race to build the HF datasets cache simultaneously on Lustre, causing corruption or crashes. Wrapping load_and_mix_datasets() in PartialState().main_process_first() lets rank 0 populate the cache first; all other ranks then read from it. --- src/post_training/methods/dpo.py | 4 +++- src/post_training/methods/sft.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/post_training/methods/dpo.py b/src/post_training/methods/dpo.py index bcdb125..09c9206 100644 --- a/src/post_training/methods/dpo.py +++ b/src/post_training/methods/dpo.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from accelerate import PartialState from trl import DPOConfig, DPOTrainer from post_training.data.loader import load_and_mix_datasets @@ -45,7 +46,8 @@ def build_dpo_trainer(config: PostTrainingConfig, run_dir: Path) -> DPOTrainer: mc = config.dpo # method-specific config tokenizer = build_tokenizer(config) - dataset = load_and_mix_datasets(config.data, row_filter=_dpo_row_filter) + with PartialState().main_process_first(): + dataset = load_and_mix_datasets(config.data, row_filter=_dpo_row_filter) dpo_config = DPOConfig( **build_common_training_kwargs(config, run_dir), diff --git a/src/post_training/methods/sft.py b/src/post_training/methods/sft.py index a34aa68..265c700 100644 --- a/src/post_training/methods/sft.py +++ b/src/post_training/methods/sft.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from accelerate import PartialState from trl import SFTConfig, SFTTrainer from post_training.data.loader import load_and_mix_datasets @@ -45,7 +46,8 @@ def build_sft_trainer(config: PostTrainingConfig, run_dir: Path) -> SFTTrainer: mc = config.sft # method-specific config tokenizer = build_tokenizer(config) - dataset = load_and_mix_datasets(config.data, row_filter=_sft_row_filter) + with PartialState().main_process_first(): + dataset = load_and_mix_datasets(config.data, row_filter=_sft_row_filter) sft_config = SFTConfig( **build_common_training_kwargs(config, run_dir),