
cp: fix: local checkpoint integration (2323) into r0.3.0 #2709

Merged
ko3n1g merged 2 commits into r0.3.0 from cherry-pick-2323-r0.3.0
Mar 16, 2026

Conversation

@svcnvidia-nemo-ci
Contributor

@svcnvidia-nemo-ci svcnvidia-nemo-ci commented Mar 9, 2026

beep boop [🤖]: Hi @ananthsub 👋,

we've cherry-picked #2323 into r0.3.0 for you! 🚀

Please review and approve this cherry-pick at your convenience!

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced local checkpoint functionality to properly preserve and restore training state during resume operations.
  • Tests

    • Added comprehensive test coverage for local checkpoint save and resume workflows, including multi-iteration training scenarios.
    • Improved test infrastructure with automatic CUDA memory cleanup and in-memory logging utilities.

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
@svcnvidia-nemo-ci
Contributor Author

/ok to test 1068b27

@copy-pr-bot

copy-pr-bot bot commented Mar 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 9, 2026

📝 Walkthrough

Walkthrough

Changes enhance local (non-persistent) checkpoint handling by embedding TrainState metadata into saved state, diverging LOCAL and non-LOCAL checkpoint save/load flows, and updating test infrastructure to support and validate local checkpoint functionality in distributed training scenarios.

Changes

Cohort / File(s) / Summary
Core Checkpointing Logic
src/megatron/bridge/training/checkpointing.py, src/megatron/bridge/training/setup.py
Embeds TrainState metadata into local checkpoints to preserve counters across resume. LOCAL checkpoints bypass wandb/mlflow callbacks and non-persistent cleanup. Save and load paths diverge: LOCAL checkpoints derive run_config from current config rather than reading files, use local-specific metadata builders, and apply optimizer loading under no_grad. Optimizes non-persistent LOCAL loading by early return for rank 0. Integrates local checkpoint detection into setup to enable loading when no global checkpoint exists.
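The TrainState-embedding pattern described above can be sketched roughly as follows. This is a minimal illustration only: the real TrainState schema and the save/load entry points in checkpointing.py differ, and the field and function names here are assumptions.

```python
from dataclasses import dataclass, asdict

# Hypothetical stand-in for the library's TrainState; the actual class
# tracks more counters than shown here.
@dataclass
class TrainState:
    step: int
    consumed_train_samples: int

def save_local_checkpoint(model_state: dict, train_state: TrainState) -> dict:
    """Embed TrainState metadata alongside the weights so the training
    counters survive a resume from a local (non-persistent) checkpoint."""
    return {
        "model": model_state,
        "train_state": asdict(train_state),  # metadata travels with the weights
    }

def load_local_checkpoint(ckpt: dict) -> tuple[dict, TrainState]:
    """Restore both the weights and the embedded training counters."""
    return ckpt["model"], TrainState(**ckpt["train_state"])
```

The point of the pattern is that a local checkpoint is self-describing: resume does not need to re-read run-config files to recover where training left off.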
Test Infrastructure
tests/functional_tests/conftest.py
Modifies reset_cuda fixture to autouse=True for automatic CUDA reset across all tests. Adds explicit garbage collection (gc.collect()) within CUDA cleanup path prior to cache emptying and synchronization.
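The fixture change amounts to the following pattern (a sketch, not the literal conftest.py code):

```python
import gc

import pytest

try:
    import torch
except ImportError:  # keep the sketch importable without a GPU stack
    torch = None

@pytest.fixture(autouse=True)  # autouse: wraps every test automatically
def reset_cuda():
    yield  # run the test body first
    if torch is not None and torch.cuda.is_available():
        gc.collect()               # drop Python references first so the
        torch.cuda.empty_cache()   # CUDA allocator can actually free blocks
        torch.cuda.synchronize()
```

Running gc.collect() before empty_cache() matters: tensors still reachable from garbage cycles keep their device memory pinned, so emptying the cache alone would not reclaim them.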
Functional Tests
tests/functional_tests/training/test_local_checkpointing.py
New test module for local checkpoint save/resume validation. Introduces Llama3ModelProvider145M configuration, TrainStateAssertCallback for progress tracking, and factory for ConfigContainer assembly. Implements two GPU-only tests verifying resume from step 5 to 10 with and without most_recent_k cleanup, validating step counters and consumed_train_samples persistence.
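The resume assertion boils down to checks like these. The class name mirrors TrainStateAssertCallback from the summary, but the interface and numbers below are illustrative, not the test module's actual code:

```python
# Illustrative sketch of the resume check the new tests perform.
class TrainStateAssertCallback:
    def __init__(self, expected_start_step: int, samples_per_step: int):
        self.expected_start_step = expected_start_step
        self.samples_per_step = samples_per_step

    def on_resume(self, train_state: dict) -> None:
        # Resume must pick up exactly where the local checkpoint left off:
        # both the step counter and the consumed-sample counter persist.
        assert train_state["step"] == self.expected_start_step
        assert (
            train_state["consumed_train_samples"]
            == self.expected_start_step * self.samples_per_step
        )
```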
Straggler Detection Test Refactoring
tests/functional_tests/training/test_nvrx_straggler.py
Replaces file-based logging with in-memory handler approach. Introduces _InMemoryHandler for log capture, _TeeWriter for stdout duplication, and attachment/detachment utilities. Adds conditional skip via pytest.mark.skipif when NVRx is unavailable. Modernizes test to validate logs and stdout contents directly instead of reading per-rank files.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 55.56%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Test Results For Major Changes (⚠️ Warning): The PR contains major changes to checkpointing infrastructure, but the description lacks test results, performance metrics, or convergence validation. Resolution: add test execution results, convergence validation, and performance impact measurements to the PR description, or reference the original PR #2323 test results.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): The title accurately describes the main purpose: cherry-picking local checkpoint integration fixes (#2323) into the r0.3.0 branch. It's specific and directly related to the changeset.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/functional_tests/training/test_nvrx_straggler.py (1)

180-192: ⚠️ Potential issue | 🟡 Minor

Docstring and code mismatch: rank 0 sleeps, not rank 1.

The docstring states "Only rank 1 will be slow to simulate a straggler scenario" but the code sleeps on rank == 0. This inconsistency could confuse future maintainers.

📝 Proposed fix
 def create_timed_forward_step_func(sleep_time: float = 1.0):
     """Create a forward step function that sleeps before calling the real forward_step.
 
     This simulates work being done and allows NVRx to measure performance differences.
-    Only rank 1 will be slow to simulate a straggler scenario.
+    Only rank 0 will be slow to simulate a straggler scenario.
 
     Args:
         sleep_time: Time to sleep in seconds before each forward step (only for rank 1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/functional_tests/training/test_nvrx_straggler.py` around lines 180 -
192, The docstring says rank 1 should be the straggler but
timed_forward_step_func currently checks torch.distributed.get_rank() == 0;
update the condition in timed_forward_step_func to check for rank 1
(torch.distributed.get_rank() == 1) so the sleep and print occur on the intended
rank, and/or update the docstring to reflect whichever behavior you choose to
keep; ensure references to torch.distributed.is_initialized(),
torch.distributed.get_rank(), and the sleep/print lines in
timed_forward_step_func are consistent.
🧹 Nitpick comments (3)
tests/functional_tests/training/test_nvrx_straggler.py (3)

262-270: Catching blind Exception loses error context for debugging.

The ruff warning (BLE001) flags catching a bare Exception. While this ensures the test reports success/failure cleanly, it may swallow unexpected errors, making debugging harder. Consider catching a more specific exception or re-raising after logging.

♻️ Proposed fix
         try:
             pretrain(config=config, forward_step_func=forward_step_func)
             training_success = True
-        except Exception:
+        except Exception as e:
             training_success = False
             if rank == 0:
                 import traceback
-
                 traceback.print_exc()
+                print(f"Training failed with: {e}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/functional_tests/training/test_nvrx_straggler.py` around lines 262 -
270, The test currently catches a bare Exception around the pretrain(...) call
which hides error details; update the try/except to either catch a more specific
exception type thrown by pretrain (or the underlying error class used in
pretrain/forward_step_func) or re-raise after logging so test failures preserve
tracebacks. Specifically, modify the except block that surrounds
pretrain(config=config, forward_step_func=forward_step_func) (and sets
training_success) to catch the concrete error(s) or to do logging using
traceback.print_exc() and then raise the exception again (preserving the
existing rank == 0 conditional logging for traceback).

256-256: _TeeWriter is used before its class definition.

_TeeWriter is instantiated on line 256 but defined on line 309. While Python allows forward references within the same file at runtime (since both are in the same module scope), this ordering reduces readability. Consider moving the _TeeWriter class definition above its first usage, alongside the other helper classes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/functional_tests/training/test_nvrx_straggler.py` at line 256, Move the
_TeeWriter class definition so it appears before its first instantiation
(currently sys.stdout = _TeeWriter(old_stdout, captured_stdout)), ideally
grouping it with the other helper classes near the top of the file; update the
file so that _TeeWriter is defined before any code that constructs it to improve
readability and avoid forward-reference confusion.

15-15: Shebang line should precede the copyright header.

The shebang (#!/usr/bin/env python3) on line 15 should be at the very top of the file (line 1), before the copyright notice, for proper script execution when invoked directly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/functional_tests/training/test_nvrx_straggler.py` at line 15, Move the
shebang "#!/usr/bin/env python3" to the very top of the file so it appears
before any header or copyright text; locate the current shebang occurrence and
reposition it as the first line of the file to ensure correct interpreter
selection when the script is executed directly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4802ba98-5074-4748-acf9-1a8849d9bd36

📥 Commits

Reviewing files that changed from the base of the PR and between bce688d and 1068b27.

📒 Files selected for processing (5)
  • src/megatron/bridge/training/checkpointing.py
  • src/megatron/bridge/training/setup.py
  • tests/functional_tests/conftest.py
  • tests/functional_tests/training/test_local_checkpointing.py
  • tests/functional_tests/training/test_nvrx_straggler.py

ananthsub previously approved these changes Mar 9, 2026

Labels

area:training (Training loop, callbacks, and runtime integration), cherry-pick, Run CICD

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants