cp: fix: local checkpoint integration (2323) into r0.3.0 #2709
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
/ok to test 1068b27
📝 Walkthrough
This PR improves local (non-persistent) checkpoint handling: it embeds TrainState metadata in the saved state, separates the LOCAL and non-LOCAL checkpoint save/load flows, and updates the test infrastructure to support and validate local checkpointing in distributed training scenarios.
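To make the walkthrough concrete, here is a minimal sketch of the idea of carrying train-state metadata inside a local checkpoint while other checkpoint types take a different path. Everything in it is an assumption for illustration: `CheckpointType`, this `TrainState` dataclass and its fields, `build_state_dict`, and `restore_train_state` are hypothetical stand-ins, and the choice to pass the train state in separately for non-LOCAL checkpoints is a guess, not something this conversation shows.

```python
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, Optional


class CheckpointType(Enum):
    """Hypothetical enum; the PR's real checkpoint types are not shown here."""

    LOCAL = "local"
    GLOBAL = "global"


@dataclass
class TrainState:
    """Hypothetical stand-in for the real TrainState."""

    step: int = 0
    consumed_samples: int = 0


def build_state_dict(
    model_state: Dict[str, Any], train_state: TrainState, ckpt_type: CheckpointType
) -> Dict[str, Any]:
    """Assemble the payload to save, embedding train state only for LOCAL checkpoints."""
    state: Dict[str, Any] = {"model": model_state}
    if ckpt_type is CheckpointType.LOCAL:
        # Assumption: a local checkpoint has nowhere else to keep this metadata,
        # so it travels inside the saved payload.
        state["train_state"] = asdict(train_state)
    return state


def restore_train_state(
    state: Dict[str, Any],
    ckpt_type: CheckpointType,
    external_train_state: Optional[TrainState] = None,
) -> TrainState:
    """Recover the train state, branching on the checkpoint type."""
    if ckpt_type is CheckpointType.LOCAL:
        return TrainState(**state["train_state"])
    if external_train_state is None:
        raise ValueError("non-LOCAL checkpoints need the train state supplied separately")
    return external_train_state
```

The point of the sketch is only the branch: the LOCAL path is self-describing, while the other path relies on metadata stored elsewhere.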
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/functional_tests/training/test_nvrx_straggler.py (1)
180-192: ⚠️ Potential issue | 🟡 Minor
Docstring and code mismatch: rank 0 sleeps, not rank 1.
The docstring states "Only rank 1 will be slow to simulate a straggler scenario", but the code sleeps when `rank == 0`. This inconsistency could confuse future maintainers.
📝 Proposed fix
     def create_timed_forward_step_func(sleep_time: float = 1.0):
         """Create a forward step function that sleeps before calling the real forward_step.
         This simulates work being done and allows NVRx to measure performance differences.
    -    Only rank 1 will be slow to simulate a straggler scenario.
    +    Only rank 0 will be slow to simulate a straggler scenario.
         Args:
             sleep_time: Time to sleep in seconds before each forward step (only for rank 1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/functional_tests/training/test_nvrx_straggler.py` around lines 180 - 192, The docstring says rank 1 should be the straggler but timed_forward_step_func currently checks torch.distributed.get_rank() == 0; update the condition in timed_forward_step_func to check for rank 1 (torch.distributed.get_rank() == 1) so the sleep and print occur on the intended rank, and/or update the docstring to reflect whichever behavior you choose to keep; ensure references to torch.distributed.is_initialized(), torch.distributed.get_rank(), and the sleep/print lines in timed_forward_step_func are consistent.
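One way to keep the docstring and the condition from drifting apart is to make the straggler rank an explicit parameter. The sketch below avoids torch so it runs anywhere: the rank lookup is injected as `get_rank` (in the real test this would presumably be `torch.distributed.get_rank`). The names `create_timed_step`, `straggler_rank`, and the injectable `sleep` are hypothetical and do not come from the test file.

```python
import time
from typing import Any, Callable


def create_timed_step(
    step_fn: Callable[..., Any],
    get_rank: Callable[[], int],
    straggler_rank: int = 1,
    sleep_time: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Callable[..., Any]:
    """Wrap step_fn so that only `straggler_rank` sleeps before each step."""

    def timed_step(*args: Any, **kwargs: Any) -> Any:
        # The docstring and this condition share one source of truth:
        # the straggler_rank argument.
        if get_rank() == straggler_rank:
            sleep(sleep_time)
        return step_fn(*args, **kwargs)

    return timed_step
```

With the rank as a parameter, the docstring can refer to `straggler_rank` instead of a hard-coded number, so changing which rank is slow needs only one edit.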
🧹 Nitpick comments (3)
tests/functional_tests/training/test_nvrx_straggler.py (3)
262-270: Catching blind `Exception` loses error context for debugging.
The ruff warning (BLE001) flags catching a blind `Exception`. While this ensures the test reports success/failure cleanly, it may swallow unexpected errors, making debugging harder. Consider catching a more specific exception or re-raising after logging.
♻️ Proposed fix
     try:
         pretrain(config=config, forward_step_func=forward_step_func)
         training_success = True
    -except Exception:
    +except Exception as e:
         training_success = False
         if rank == 0:
             import traceback
    -        traceback.print_exc()
    +        print(f"Training failed with: {e}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/functional_tests/training/test_nvrx_straggler.py` around lines 262 - 270, The test currently catches a bare Exception around the pretrain(...) call which hides error details; update the try/except to either catch a more specific exception type thrown by pretrain (or the underlying error class used in pretrain/forward_step_func) or re-raise after logging so test failures preserve tracebacks. Specifically, modify the except block that surrounds pretrain(config=config, forward_step_func=forward_step_func) (and sets training_success) to catch the concrete error(s) or to do logging using traceback.print_exc() and then raise the exception again (preserving the existing rank == 0 conditional logging for traceback).
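The "re-raise after logging" option mentioned above can be sketched as follows. This is a generic pattern under stated assumptions, not the test's actual code: `run_and_report`, `train_fn`, and the `log` parameter are hypothetical, and `train_fn` stands in for the `pretrain(...)` call.

```python
import io
import traceback
from typing import Callable, Optional, TextIO


def run_and_report(train_fn: Callable[[], None], rank: int, log: Optional[TextIO] = None) -> bool:
    """Run train_fn; on failure, print the traceback on rank 0 and re-raise."""
    try:
        train_fn()
    except Exception:
        if rank == 0:
            # file=None makes print_exc fall back to sys.stderr.
            traceback.print_exc(file=log)
        # Re-raising keeps the original exception and traceback intact for the caller.
        raise
    return True
```

A caller that still needs a success flag can wrap this in its own narrow `try`/`except`, while the test runner sees the real failure instead of a swallowed one.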
256-256: `_TeeWriter` is used before its class definition.
`_TeeWriter` is instantiated on line 256 but defined on line 309. This works at runtime because the name is looked up in module scope only when the enclosing function is called, by which point the whole module has been executed, but the ordering reduces readability. Consider moving the `_TeeWriter` class definition above its first usage, alongside the other helper classes.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/functional_tests/training/test_nvrx_straggler.py` at line 256, Move the _TeeWriter class definition so it appears before its first instantiation (currently sys.stdout = _TeeWriter(old_stdout, captured_stdout)), ideally grouping it with the other helper classes near the top of the file; update the file so that _TeeWriter is defined before any code that constructs it to improve readability and avoid forward-reference confusion.
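For illustration, a tee-style writer defined ahead of its first use might look like the sketch below. It is not the `_TeeWriter` from the test file, whose body is not shown in this conversation; `TeeWriter` and `capture_output` are hypothetical names, and the only assumption carried over is the two-stream constructor seen in `_TeeWriter(old_stdout, captured_stdout)`.

```python
import io
from typing import Callable, TextIO, Tuple


class TeeWriter:
    """Duplicate every write to a primary and a secondary text stream."""

    def __init__(self, primary: TextIO, secondary: TextIO) -> None:
        self._primary = primary
        self._secondary = secondary

    def write(self, text: str) -> int:
        self._secondary.write(text)
        return self._primary.write(text)

    def flush(self) -> None:
        self._primary.flush()
        self._secondary.flush()


def capture_output(emit: Callable[[TextIO], None]) -> Tuple[str, str]:
    """Call emit with a tee; return (passthrough_text, captured_text)."""
    passthrough, captured = io.StringIO(), io.StringIO()
    # The helper class is defined above its first use, so the file reads top to bottom.
    emit(TeeWriter(passthrough, captured))
    return passthrough.getvalue(), captured.getvalue()
```

Both streams receive identical text, which is what lets a test keep normal console output while also asserting on what was printed.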
15-15: Shebang line should precede the copyright header.
The shebang (`#!/usr/bin/env python3`) on line 15 should be at the very top of the file (line 1), before the copyright notice, for proper script execution when invoked directly.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/functional_tests/training/test_nvrx_straggler.py` at line 15, Move the shebang "#!/usr/bin/env python3" to the very top of the file so it appears before any header or copyright text; locate the current shebang occurrence and reposition it as the first line of the file to ensure correct interpreter selection when the script is executed directly.
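The intended header order can be sketched as below. The copyright line is a placeholder rather than the file's real header, and `main` is a hypothetical entry point added only so the snippet is a complete module.

```python
#!/usr/bin/env python3
# Copyright (c) ... (placeholder for the real license header)
"""Hypothetical script showing the shebang as the very first line."""


def main() -> int:
    """Return a success exit code; stands in for the real test entry point."""
    return 0
```

The operating system only honors the `#!` line when it is the first line of the file, so any header comment has to come after it.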
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4802ba98-5074-4748-acf9-1a8849d9bd36
📒 Files selected for processing (5)
src/megatron/bridge/training/checkpointing.py
src/megatron/bridge/training/setup.py
tests/functional_tests/conftest.py
tests/functional_tests/training/test_local_checkpointing.py
tests/functional_tests/training/test_nvrx_straggler.py
beep boop [🤖]: Hi @ananthsub 👋,
Summary by CodeRabbit
Release Notes
New Features
Tests