Make checkpointing better#1647

Draft
finbarrtimbers wants to merge 11 commits into main from finbarr/check-script-in

@finbarrtimbers finbarrtimbers commented Apr 29, 2026

Summary

Pull the checkpoint-state save into the timed save path and cut redundant work along the way.

  • Move the inline checkpoint-state save out of run_training and into a new maybe_save_checkpoint_state helper called from one_training_step, so its duration counts toward the time/saving metric and num_total_tokens reflects the just-finished step.
  • Log checkpoint size (GiB) and average write bandwidth (MiB/s) for the most recent global_step{N} directory after each save, for I/O monitoring.
  • Stop pickling the dataloader and data-prep-actor state into every rank's mp_rank_*_model_states.pt. Rank 0 now writes a single driver_state.pt next to the DeepSpeed checkpoint; the load path picks it up and merges into states.
  • Skip-when-clean for the reference policy: introduce a shared should_save_ref_policy(args, training_step) predicate and use it both at the EMA-update site and the save site, so off-cadence checkpoints don't rewrite an unchanged ref-policy file. An assert in main() enforces checkpoint_state_freq % ref_policy_update_freq == 0 so saves always land on update steps.
  • Drop the redundant per-device torch_cuda_rng_states dict; torch_cuda_rng_state_all already covers every device.
  • Restore short WHY comments on the two non-obvious bits: ref_policy bypassing DeepSpeed's saver (DummyOptim has no state_dict) and the mpu detach / all-ranks save_checkpoint contract.
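The skip-when-clean rule and the cadence assert from the list above can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's code: the `Args` dataclass, the `validate_cadence` helper, and the assumption that the reference policy updates exactly when `training_step % ref_policy_update_freq == 0` are hypothetical. Only the names `should_save_ref_policy`, `load_ref_policy`, `checkpoint_state_freq`, and `ref_policy_update_freq` come from the description.

```python
from dataclasses import dataclass


@dataclass
class Args:
    # Hypothetical stand-in for the training args; field names follow the PR text.
    load_ref_policy: bool = True
    checkpoint_state_freq: int = 100
    ref_policy_update_freq: int = 50


def should_save_ref_policy(args: Args, training_step: int) -> bool:
    """Return True only on steps where the ref policy is assumed to have just changed."""
    if not args.load_ref_policy or args.ref_policy_update_freq <= 0:
        return False
    return training_step % args.ref_policy_update_freq == 0


def validate_cadence(args: Args) -> None:
    # Mirrors the assert described for main(): every checkpoint step must be an update step.
    assert args.checkpoint_state_freq % args.ref_policy_update_freq == 0, (
        "checkpoint_state_freq must be a multiple of ref_policy_update_freq"
    )
```

Sharing one predicate between the update site and the save site means the two cannot drift apart: a step either updates and saves the reference policy, or does neither.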

No changes to checkpoint format on the load side beyond reading driver_state.pt when present.
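A minimal sketch of the rank-0 write and the tolerant load path, assuming the driver state is a plain dictionary. The function signatures, the `load_and_merge_driver_state` name, and the write-then-rename step are assumptions for illustration. `pickle` stands in for `torch.save` / `torch.load` so the example runs without PyTorch.

```python
import os
import pickle
from typing import Any

DRIVER_STATE_FILE = "driver_state.pt"  # file name taken from the PR description


def save_driver_state(step_dir: str, rank: int, driver_state: dict[str, Any]) -> None:
    """Rank 0 writes one shared file; every other rank writes nothing."""
    if rank != 0:
        return
    tmp_path = os.path.join(step_dir, DRIVER_STATE_FILE + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump(driver_state, f)  # stand-in for torch.save (assumption)
    # Rename after the write completes so readers never see a partial file (assumption).
    os.replace(tmp_path, os.path.join(step_dir, DRIVER_STATE_FILE))


def load_and_merge_driver_state(step_dir: str, states: dict[str, Any]) -> dict[str, Any]:
    """Merge driver_state.pt into states when present; older checkpoints lack the file."""
    path = os.path.join(step_dir, DRIVER_STATE_FILE)
    if os.path.exists(path):
        with open(path, "rb") as f:
            states.update(pickle.load(f))
    return states
```

Because the load side only merges the file when it exists, checkpoints written before this change would still load under this sketch.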

Test plan

  • Run a debug GRPO script with --checkpoint_state_freq set and confirm the new size/bandwidth log line appears and that time/saving reflects the checkpoint-state save.
  • Resume from a fresh checkpoint and verify dataloader / data-prep-actor state restore via driver_state.pt.
  • Confirm with load_ref_policy=True and a configured ref_policy_update_freq (a divisor of checkpoint_state_freq, as the assert in main() requires) that the ref-policy file is only rewritten on update steps.


@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request enhances the checkpointing system for GRPO training by introducing driver state persistence (dataloader and actor states) and refactoring reference policy updates. It also adds a training script for Qwen 2.5 0.5B. Key feedback includes addressing a potential backward compatibility issue with RNG state restoration, improving consistency in checkpoint directory identification by using the 'latest' file, and mitigating performance bottlenecks caused by synchronous file system operations during checkpoint logging.

Comment on lines +436 to +437
    if torch.cuda.is_available():
        torch.cuda.set_rng_state_all(rng_states["torch_cuda_rng_state_all"])

high

The code now assumes torch_cuda_rng_state_all is always present in rng_states when CUDA is available. This will cause a KeyError when attempting to resume from older checkpoints that lack this specific key. It is safer to check for the key's existence to maintain backward compatibility.

Suggested change
    -if torch.cuda.is_available():
    -    torch.cuda.set_rng_state_all(rng_states["torch_cuda_rng_state_all"])
    +if torch.cuda.is_available() and "torch_cuda_rng_state_all" in rng_states:
    +    torch.cuda.set_rng_state_all(rng_states["torch_cuda_rng_state_all"])

Comment on lines +1809 to +1819
    step_dirs = [
        d
        for d in os.listdir(args.checkpoint_state_dir)
        if d.startswith("global_step")
        and d[len("global_step") :].isdigit()
        and os.path.isdir(os.path.join(args.checkpoint_state_dir, d))
    ]
    if step_dirs:
        latest_dir = os.path.join(
            args.checkpoint_state_dir, max(step_dirs, key=lambda d: int(d[len("global_step") :]))
        )

medium

Manually searching for the latest checkpoint directory by parsing global_step* names is redundant and potentially inconsistent with DeepSpeed's latest file mechanism. Since save_driver_state (line 858) already relies on the latest file as the source of truth, this function should ideally do the same to ensure it measures the correct directory.

    latest_file = os.path.join(args.checkpoint_state_dir, "latest")
    latest_dir = None
    if os.path.exists(latest_file):
        with open(latest_file) as f:
            latest_dir = os.path.join(args.checkpoint_state_dir, f.read().strip())
    if latest_dir:

Comment on lines +1820 to +1822
    total_bytes = sum(
        os.path.getsize(os.path.join(root, f)) for root, _, files in os.walk(latest_dir) for f in files
    )

medium

Calculating the total checkpoint size using os.walk and os.path.getsize on the main thread can introduce significant latency, especially when using distributed filesystems or when dealing with large model checkpoints (e.g., 70B+ parameters). Since this information is only used for logging, consider moving this calculation to a background thread or making it optional to avoid stalling the training loop.
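One way to act on this suggestion is to hand the directory walk to a single background worker, so the training loop only pays for submitting the task. The sketch below is an assumption-laden illustration, not code from the PR: `dir_size_bytes`, `format_checkpoint_stats`, `log_checkpoint_stats_async`, and the log message format are all hypothetical names and choices. The GiB and MiB/s units follow the PR summary.

```python
import os
from concurrent.futures import Future, ThreadPoolExecutor

# One worker is enough: size logging is low priority and should not compete for I/O.
_size_executor = ThreadPoolExecutor(max_workers=1)


def dir_size_bytes(path: str) -> int:
    """Sum the sizes of all regular files under path."""
    total = 0
    for root, _, files in os.walk(path):
        for name in files:
            try:
                total += os.path.getsize(os.path.join(root, name))
            except OSError:
                pass  # file vanished or is unreadable; skip it rather than fail logging
    return total


def format_checkpoint_stats(total_bytes: int, save_seconds: float) -> str:
    """Render size in GiB and average write bandwidth in MiB/s."""
    gib = total_bytes / 2**30
    mib_per_s = (total_bytes / 2**20) / save_seconds if save_seconds > 0 else float("nan")
    return f"checkpoint size: {gib:.2f} GiB, avg write bandwidth: {mib_per_s:.1f} MiB/s"


def log_checkpoint_stats_async(latest_dir: str, save_seconds: float) -> Future:
    """Submit the walk to the background worker and return immediately."""

    def _work() -> str:
        msg = format_checkpoint_stats(dir_size_bytes(latest_dir), save_seconds)
        print(msg)  # a real training script would use its logger here
        return msg

    return _size_executor.submit(_work)
```

The caller can ignore the returned `Future` entirely, or call `.result()` in tests to check the message. One caveat of this design is that a checkpoint directory deleted or rotated before the worker runs would be reported as smaller than it was.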
