Make checkpointing better#1647
Conversation
…-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…-By: Claude Opus 4.7 <noreply@anthropic.com>
… of $@ Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…g size/bandwidth Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…icy rewrites, RNG dup Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…UDA RNG Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ore driver_state guard Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…land on update steps Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request enhances the checkpointing system for GRPO training by introducing driver state persistence (dataloader and actor states) and refactoring reference policy updates. It also adds a training script for Qwen 2.5 0.5B. Key feedback includes addressing a potential backward compatibility issue with RNG state restoration, improving consistency in checkpoint directory identification by using the 'latest' file, and mitigating performance bottlenecks caused by synchronous file system operations during checkpoint logging.
| if torch.cuda.is_available(): | ||
| torch.cuda.set_rng_state_all(rng_states["torch_cuda_rng_state_all"]) |
There was a problem hiding this comment.
The code now assumes torch_cuda_rng_state_all is always present in rng_states when CUDA is available. This will cause a KeyError when attempting to resume from older checkpoints that lack this specific key. It is safer to check for the key's existence to maintain backward compatibility.
| if torch.cuda.is_available(): | |
| torch.cuda.set_rng_state_all(rng_states["torch_cuda_rng_state_all"]) | |
| if torch.cuda.is_available() and "torch_cuda_rng_state_all" in rng_states: | |
| torch.cuda.set_rng_state_all(rng_states["torch_cuda_rng_state_all"]) |
| step_dirs = [ | ||
| d | ||
| for d in os.listdir(args.checkpoint_state_dir) | ||
| if d.startswith("global_step") | ||
| and d[len("global_step") :].isdigit() | ||
| and os.path.isdir(os.path.join(args.checkpoint_state_dir, d)) | ||
| ] | ||
| if step_dirs: | ||
| latest_dir = os.path.join( | ||
| args.checkpoint_state_dir, max(step_dirs, key=lambda d: int(d[len("global_step") :])) | ||
| ) |
There was a problem hiding this comment.
Manually searching for the latest checkpoint directory by parsing global_step* names is redundant and potentially inconsistent with DeepSpeed's latest file mechanism. Since save_driver_state (line 858) already relies on the latest file as the source of truth, this function should ideally do the same to ensure it measures the correct directory.
latest_file = os.path.join(args.checkpoint_state_dir, "latest")
latest_dir = None
if os.path.exists(latest_file):
with open(latest_file) as f:
latest_dir = os.path.join(args.checkpoint_state_dir, f.read().strip())
if latest_dir:| total_bytes = sum( | ||
| os.path.getsize(os.path.join(root, f)) for root, _, files in os.walk(latest_dir) for f in files | ||
| ) |
There was a problem hiding this comment.
Calculating the total checkpoint size using os.walk and os.path.getsize on the main thread can introduce significant latency, especially when using distributed filesystems or when dealing with large model checkpoints (e.g., 70B+ parameters). Since this information is only used for logging, consider moving this calculation to a background thread or making it optional to avoid stalling the training loop.
…pu/save dance Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…y@anthropic.com>
Summary
Pull the checkpoint-state save into the timed save path and cut redundant work along the way.
run_trainingand into a newmaybe_save_checkpoint_statehelper called fromone_training_step, so its duration counts toward thetime/savingmetric andnum_total_tokensreflects the just-finished step.global_step{N}directory after each save, for I/O monitoring.mp_rank_*_model_states.pt. Rank 0 now writes a singledriver_state.ptnext to the DeepSpeed checkpoint; the load path picks it up and merges intostates.should_save_ref_policy(args, training_step)predicate and use it both at the EMA-update site and the save site, so off-cadence checkpoints don't rewrite an unchanged ref-policy file. An assert inmain()enforcescheckpoint_state_freq % ref_policy_update_freq == 0so saves always land on update steps.torch_cuda_rng_statesdict;torch_cuda_rng_state_allalready covers every device.state_dict) and the mpu detach / all-rankssave_checkpointcontract.No changes to checkpoint format on the load side beyond reading
driver_state.ptwhen present.Test plan
--checkpoint_state_freqset and confirm the new size/bandwidth log line appears and thattime/savingreflects the checkpoint-state save.driver_state.pt.load_ref_policy=Trueand a configuredref_policy_update_freq(multiple ofcheckpoint_state_freq) that the ref-policy file is only rewritten on update steps.