
Conversation

@guyueh1
Contributor

@guyueh1 guyueh1 commented Sep 15, 2025

What does this PR do?

Support virtual pipeline parallelism (VPP) in mcore.

Issues

closes #1038

Usage

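A minimal sketch of the relevant configuration, assuming the usual policy.megatron_cfg YAML layout (nesting and values are illustrative, not taken from this PR):

policy:
  megatron_cfg:
    pipeline_model_parallel_size: 4          # VPP only applies when PP > 1
    virtual_pipeline_model_parallel_size: 2  # > 1 enables interleaved (virtual) pipeline stages

With PP = 4 and VPP = 2, each pipeline rank hosts two model chunks, which shrinks pipeline bubbles at the cost of more frequent point-to-point communication.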

Before your PR is "Ready for review"

Pre-checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features
    • Added multi-model support for training, inference, checkpoints, and exports.
    • Sequence packing now adapts to data/pipeline parallelism with configurable bin counts for improved throughput.
    • Improved handling for virtual pipeline stages with per-stage data iterators.
  • Refactor
    • Internal model handling updated to operate on multiple sub-models consistently across workflows.
  • Chores
    • Updated Megatron-Bridge workspace submodule (no user-visible changes).

Signed-off-by: Guyue Huang <guyueh@nvidia.com>
@coderabbitai
Contributor

coderabbitai bot commented Sep 15, 2025

📝 Walkthrough

Updates include a submodule pointer bump, introduction of two optional sequence packing parameters, pipeline-parallel-aware binning in LM policy training, broad refactor of MegatronPolicyWorker to handle a list of models (multi-model/pipeline/virtual-pipeline aware), and a corresponding utility change to read modules from the first model in the list.

Changes

  • Submodule pointer update (3rdparty/Megatron-Bridge-workspace/Megatron-Bridge):
    Submodule reference updated from abd52c89... to 30f24c66.... No code or API changes in this repo.
  • Sequence packing args extension (nemo_rl/distributed/batched_data_dict.py):
    Adds optional fields min_bin_count and bin_count_multiple to SequencePackingArgs; shard_by_batch_size passes them to the bin packer, defaulting to the shard count when they are falsy.
  • PP-aware sequence packing config (nemo_rl/models/policy/lm_policy.py):
    Computes pp_size; when sequence packing is used, sets min_bin_count and bin_count_multiple to dp_size * pp_size. No changes to the dynamic batching path.
  • Multi-model worker refactor and pipeline data flow (nemo_rl/models/policy/megatron_policy_worker.py):
    Treats self.model as a list; updates hooks, training, device moves, checkpointing, exports, parameter updates, and caches to iterate per sub-model. Switches the data iterator to a list of iterators when virtual_pipeline_model_parallel_size > 1. Configures overlap of P2P communication for multi-stage pipelines. Adds rank-aware log prints. No public API signature changes.
  • Utils aligned to model list (nemo_rl/models/policy/utils.py):
    get_gpu_info now inspects modules from model[0].named_modules() instead of model.named_modules(). Return structure unchanged.
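The common thread in the worker refactor is a list-of-chunks pattern; roughly (a sketch with illustrative names, not the worker's exact API):

# self.model is now a list with one entry per (virtual) pipeline model chunk
for sub_model in self.model:
    sub_model.to(device)                    # device moves iterate per chunk
state_dicts = [m.state_dict() for m in self.model]  # checkpoint/export paths do likewise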

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Trainer
  participant LMPolicy
  participant ParallelConfig as Parallel Config
  participant SeqPack as Sequence Packing Args
  participant BinPacker

  Trainer->>LMPolicy: start_training()
  LMPolicy->>ParallelConfig: get dp_size, pp_size
  ParallelConfig-->>LMPolicy: dp_size, pp_size
  LMPolicy->>SeqPack: set min_bin_count = dp_size*pp_size\nset bin_count_multiple = dp_size*pp_size
  LMPolicy->>BinPacker: construct(..., min_bin_count, bin_count_multiple)
  BinPacker-->>LMPolicy: ready
  LMPolicy-->>Trainer: proceed with training
sequenceDiagram
  autonumber
  actor Orchestrator
  participant Worker as MegatronPolicyWorker
  participant Models as [Model_0, Model_1, ...]
  participant DataIt as Data Iterator(s)
  participant Checkpoint as Checkpoint/Export

  Orchestrator->>Worker: initialize(models=list)
  Worker->>Models: for each: move to device, set hooks
  alt virtual_pipeline_model_parallel_size > 1
    Worker->>DataIt: create list of iterators (per VP stage)
  else
    Worker->>DataIt: create single iterator
  end
  loop training steps
    Worker->>Models: for each: zero grads, reset caches
    Worker->>Models: forward/backward (passes list to step fn)
  end
  Worker->>Checkpoint: save(model=list), export weights
  Worker-->>Orchestrator: done

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I thump with glee on parallel trails,
Bins now hop in DP×PP scales.
A warren of models, lined in a row,
Each takes a nibble, then onward we go.
Pack, iterate, checkpoint—done!
Carrots aligned, workloads run.
Happy hops under the training sun.

Pre-merge checks

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 64.29%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The PR title 'perf: Support mcore virtual pipeline parallel' is partially related to the changeset. Virtual pipeline parallel support is a significant aspect of the changes (particularly in megatron_policy_worker.py), but the 'perf:' prefix suggests a performance optimization, while the changes are primarily feature additions (multi-model support, new optional fields, pipeline-parallel awareness) rather than performance improvements. The title captures one aspect but undersells the broader scope of the changes.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
nemo_rl/models/policy/utils.py (1)

15-18: get_gpu_info should accept a single model or a list of models.

Using model[0] assumes indexability and breaks dtensor paths that pass a single nn.Module. Make it robust and update typing.

-from typing import Any, Dict
+from typing import Any, Dict, Sequence
-def get_gpu_info(model: torch.nn.Module) -> dict[str, Any]:
+def get_gpu_info(model: torch.nn.Module | Sequence[torch.nn.Module]) -> dict[str, Any]:
-    for module_name, module in model[0].named_modules():
+    first_model = model[0] if isinstance(model, (list, tuple)) else model
+    for module_name, module in first_model.named_modules():

Also applies to: 103-104, 134-146

nemo_rl/models/policy/megatron_policy_worker.py (2)

1312-1376: use_reference_model is broken with list-of-models (loads single state_dict into list).

self.model is now a list; this context manager calls self.model.state_dict(), which will crash, and reference_state_dict holds only the first chunk's weights. Load/save must iterate over all sub-models.

Apply:

-            try:
-                # Save original references
-                model_state_dict = {}
-                for name, item in self.model.state_dict().items():
-                    if isinstance(item, torch.Tensor):
-                        item = item.detach().to(
-                            device="cpu", non_blocking=True, copy=True
-                        )
-                    model_state_dict[name] = item
-
-                self.model.load_state_dict(self.reference_state_dict, strict=True)
+            try:
+                # Save original references for all sub-models
+                model_state_dicts = []
+                for m in self.model:
+                    sd = {}
+                    for name, item in m.state_dict().items():
+                        if isinstance(item, torch.Tensor):
+                            item = item.detach().to(device="cpu", non_blocking=True, copy=True)
+                        sd[name] = item
+                    model_state_dicts.append(sd)
+
+                # Load per-submodel reference weights
+                assert hasattr(self, "reference_state_dicts"), "reference_state_dicts missing"
+                for m, ref_sd in zip(self.model, self.reference_state_dicts):
+                    m.load_state_dict(ref_sd, strict=True)
@@
-                # Restore original references and device placement
-                self.model.load_state_dict(model_state_dict, strict=True)
+                # Restore original references per sub-model
+                for m, sd in zip(self.model, model_state_dicts):
+                    m.load_state_dict(sd, strict=True)

And in the reference load path, persist all chunks instead of the first:

-                reference_model = reference_model[0]
-                reference_model.eval()
-                self.reference_state_dict = {}
-                for name, item in reference_model.state_dict().items():
-                    ...
-                print("Reference model loaded")
+                for rm in reference_model:
+                    rm.eval()
+                self.reference_state_dicts = []
+                for rm in reference_model:
+                    ref_sd = {}
+                    for name, item in rm.state_dict().items():
+                        cpu_item = item.detach().to(device="cpu", non_blocking=True, copy=True) if isinstance(item, torch.Tensor) else item
+                        ref_sd[name] = cpu_item
+                    self.reference_state_dicts.append(ref_sd)
+                print("Reference model loaded")

1811-1860: Teach move_model to handle lists of sub-models.

Centralize list handling to prevent crashes at multiple call sites (prepare_for_lp_inference, offload_before_refit, offload_after_refit).

-    def move_model(
-        self,
-        model: torch.nn.Module,
+    def move_model(
+        self,
+        model: torch.nn.Module | list[torch.nn.Module],
         device: str,
         move_params: bool = True,
         move_grads: bool = True,
-    ) -> torch.nn.Module:
+    ) -> torch.nn.Module | list[torch.nn.Module]:
-        # move all param and grad buffers to the device
+        # handle lists
+        if isinstance(model, list):
+            return [self.move_model(m, device, move_params, move_grads) for m in model]
+
+        # move all param and grad buffers to the device
         if isinstance(model, DistributedDataParallel):
             ...
         elif isinstance(model, custom_FSDP):
             ...
         else:
             # Ordinary offload case
             if move_params:
                 for name, param in model.state_dict().items():
                     new_state_dict = {}
                     for name, item in model.state_dict().items():
                         if isinstance(item, torch.Tensor):
                             item = item.detach().to(
                                 device=device, non_blocking=True, copy=True
                             )
                         new_state_dict[name] = item
                     model.load_state_dict(new_state_dict)
         return model
🧹 Nitpick comments (3)
3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1)

1-1: Remove moving branch from .gitmodules for Megatron-Bridge

Found .gitmodules: submodule "3rdparty/Megatron-Bridge" contains branch = yifu/nemo-rl-use-chunkpatch-ds; CI fetches submodules recursively. Remove that branch entry so CI pins to the recorded commit, and add an upstream diff (old..new) link in the PR description for auditability.
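For reference, the change amounts to dropping the branch line from an entry of roughly this shape (reconstructed from the path, URL, and branch cited in this review, not copied from the repo):

[submodule "3rdparty/Megatron-Bridge"]
    path = 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
    url = https://github.com/NVIDIA-NeMo/Megatron-Bridge.git
    branch = yifu/nemo-rl-use-chunkpatch-ds   # remove so CI pins to the recorded commit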

nemo_rl/distributed/batched_data_dict.py (1)

282-295: Update docs to mention the two new (optional) packing knobs.

Docstring still says “requires five keys”; please add min_bin_count and bin_count_multiple and call them optional with default fallback to shards.

nemo_rl/models/policy/megatron_policy_worker.py (1)

877-881: Reduce noisy prints in hot path.

Per-GPU prints on every global batch will spam logs. Gate them behind a debug flag or env guard.

-        print(f"rank {get_rank_safe()}: total_dataset_size: {total_dataset_size.item()}")
+        if os.environ.get("NRL_DEBUG_TRAIN", "0") == "1":
+            print(f"rank {get_rank_safe()}: total_dataset_size: {total_dataset_size.item()}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 94a3d49 and 71e5756.

📒 Files selected for processing (5)
  • 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1 hunks)
  • nemo_rl/distributed/batched_data_dict.py (2 hunks)
  • nemo_rl/models/policy/lm_policy.py (2 hunks)
  • nemo_rl/models/policy/megatron_policy_worker.py (18 hunks)
  • nemo_rl/models/policy/utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
nemo_rl/models/policy/lm_policy.py (1)
nemo_rl/distributed/named_sharding.py (1)
  • get_axis_size (209-211)
nemo_rl/models/policy/megatron_policy_worker.py (2)
nemo_rl/models/policy/dtensor_policy_worker.py (2)
  • train (522-888)
  • save_checkpoint (1487-1505)
nemo_rl/distributed/batched_data_dict.py (1)
  • make_microbatch_iterator_for_packable_sequences (763-777)
🔇 Additional comments (8)
3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1)

1-1: Submodule bump — new SHA present upstream; local diff unavailable; verify upstream changes and branch.

  • Old → New: abd52c89fe969869b8969acc181630c273cca4fd → 30f24c667c3e909f861dd583ca4896552a70d9e9 (remote: https://github.com/NVIDIA-NeMo/Megatron-Bridge.git).
  • Submodule not initialized in this environment; run locally: cd 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge && git fetch --all && git diff --stat abd52c89..30f24c66 or git log --oneline abd52c89..30f24c66.
  • .gitmodules branch: yifu/nemo-rl-use-chunkpatch-ds — confirm this non-default branch is intentional and that the commit is reachable and CI-passing.
  • Review upstream diff for regressions and confirm provenance before merging.
nemo_rl/models/policy/lm_policy.py (1)

359-361: Confirm binning multiplicity with virtual pipeline.

You set min_bin_count/bin_count_multiple = dp_size * pp_size. With virtual pipeline parallelism (vp_size > 1), do we also need to factor in vp_size so microbatches align per virtual stage? Please confirm against MCore forward_backward expectations.

If needed, the adjustment would be:

- self.sequence_packing_args["min_bin_count"] = dp_size * pp_size
- self.sequence_packing_args["bin_count_multiple"] = dp_size * pp_size
+ vp_size = self.cfg["megatron_cfg"].get("virtual_pipeline_model_parallel_size", 1)
+ self.sequence_packing_args["min_bin_count"] = dp_size * pp_size * vp_size
+ self.sequence_packing_args["bin_count_multiple"] = dp_size * pp_size * vp_size
nemo_rl/models/policy/megatron_policy_worker.py (6)

528-534: Suspicious assignment: microbatch_group_size_per_vp_stage.

Setting microbatch_group_size_per_vp_stage = pipeline_model_parallel_size looks off; this knob typically relates to virtual pipeline grouping. Please verify against MCore config; it may need virtual_pipeline_model_parallel_size or a derived grouping value instead.


554-557: Good: enable overlap_p2p_comm when PP and VP are used.

This should reduce bubbles; disabling batch_p2p_comm here is consistent.
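For readers unfamiliar with these knobs, the gating amounts to something like the following (a runnable sketch; PipelineCommConfig is a stand-in for the relevant Megatron-Core ModelParallelConfig fields, and the worker's exact condition may differ):

from dataclasses import dataclass

@dataclass
class PipelineCommConfig:
    overlap_p2p_comm: bool = False  # overlap pipeline P2P with compute
    batch_p2p_comm: bool = True     # batched P2P; must be off when overlapping

def configure_p2p(config: PipelineCommConfig, pp_size: int, vp_size: int | None) -> None:
    # Enable overlap only for interleaved (VPP) pipelines with real pipeline stages.
    if pp_size > 1 and (vp_size or 1) > 1:
        config.overlap_p2p_comm = True
        config.batch_p2p_comm = False

cfg = PipelineCommConfig()
configure_p2p(cfg, pp_size=4, vp_size=2)
print(cfg)  # PipelineCommConfig(overlap_p2p_comm=True, batch_p2p_comm=False)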


695-756: List-of-model handling during reference-model setup is on the right track.

Moving each sub-model to CPU before instantiating the reference and then restoring to CUDA is correct for multi-chunk models.


952-960: VP-aware iterator list looks correct.

Providing a list of iterators per VP stage when vp_size > 1 aligns with virtual pipeline semantics.
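Concretely, the approved shape is along these lines (the helper name comes from nemo_rl/distributed/batched_data_dict.py as referenced above; the wrapper function and variable names are assumptions):

def build_data_iterator(data, vp_size):
    """Return one microbatch iterator per virtual-pipeline chunk, or a single iterator."""
    if vp_size is not None and vp_size > 1:
        # The interleaved schedule consumes a list with one iterator per model chunk,
        # each replaying the same packed microbatch stream.
        return [
            data.make_microbatch_iterator_for_packable_sequences()
            for _ in range(vp_size)
        ]
    return data.make_microbatch_iterator_for_packable_sequences()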


979-1000: Verify forward_backward accepts list iterators for VP.

There’s a TODO; ensure get_forward_backward_func supports a List[Iterator] for data_iterator with VP, and that num_microbatches matches the per-stage iterator lengths.
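For context, recent Megatron-Core interleaved schedules expect a list-valued data_iterator (one per chunk); the call shape is roughly as below, with placeholder argument values; verify against the pinned Megatron-Core version:

from megatron.core.pipeline_parallel import get_forward_backward_func

forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
    forward_step_func=forward_step,     # user-supplied per-microbatch step
    data_iterator=data_iterator,        # list[Iterator] when vp_size > 1, else a single iterator
    model=self.model,                   # list of model chunks
    num_microbatches=num_microbatches,  # should match the per-stage iterator lengths
    seq_length=seq_length,
    micro_batch_size=micro_batch_size,
    forward_only=False,
)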


1721-1724: Handle list-of-modules when calling move_model

prepare_for_lp_inference passes self.model (a list) into move_model, which expects a single module and will fail at runtime. Map over the list here or make move_model accept iterables.

Minimal local fix:

-        self.model = self.move_model(self.model, "cuda", move_grads=False)
+        self.model = [self.move_model(m, "cuda", move_grads=False) for m in self.model]

Prefer generalizing move_model() so all call sites accept lists.

Comment on lines +57 to +58
min_bin_count: Optional[int]
bin_count_multiple: Optional[int]

🛠️ Refactor suggestion

Make optional keys truly optional in SequencePackingArgs (typing fix).

As written, Optional[int] in a total=True TypedDict still requires the keys. Use NotRequired so callers can omit them.

Apply:

- from typing_extensions import Self
+ from typing_extensions import Self, NotRequired
-    min_bin_count: Optional[int]
-    bin_count_multiple: Optional[int]
+    min_bin_count: NotRequired[int]
+    bin_count_multiple: NotRequired[int]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-min_bin_count: Optional[int]
-bin_count_multiple: Optional[int]
+from typing_extensions import Self, NotRequired
+min_bin_count: NotRequired[int]
+bin_count_multiple: NotRequired[int]
🤖 Prompt for AI Agents
In nemo_rl/distributed/batched_data_dict.py around lines 57 to 58, the TypedDict
declares min_bin_count: Optional[int] and bin_count_multiple: Optional[int]
which (with total=True) still require the keys; change those annotations to use
NotRequired[int] and import NotRequired (from typing if on Python 3.11+,
otherwise from typing_extensions) so callers can omit the keys, and remove
Optional for those fields.
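A minimal, self-contained illustration of the distinction (trimmed to the two keys under discussion):

from typing import Optional, TypedDict
from typing_extensions import NotRequired

class WithOptionalValues(TypedDict):
    min_bin_count: Optional[int]     # key required; value may be None

class WithOptionalKeys(TypedDict):
    min_bin_count: NotRequired[int]  # key may be omitted entirely

a: WithOptionalValues = {"min_bin_count": None}  # key must be present
b: WithOptionalKeys = {}                         # OK: key omitted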

Comment on lines 414 to 416
min_bin_count=sequence_packing_args["min_bin_count"] or shards,
bin_count_multiple=sequence_packing_args["bin_count_multiple"] or shards,
)

⚠️ Potential issue

Avoid KeyError when optional fields are absent.

Accessing optional keys via dict indexing will raise KeyError. Use .get(...) with fallback.

-                min_bin_count=sequence_packing_args["min_bin_count"] or shards,
-                bin_count_multiple=sequence_packing_args["bin_count_multiple"] or shards,
+                min_bin_count=sequence_packing_args.get("min_bin_count") or shards,
+                bin_count_multiple=sequence_packing_args.get("bin_count_multiple") or shards,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-min_bin_count=sequence_packing_args["min_bin_count"] or shards,
-bin_count_multiple=sequence_packing_args["bin_count_multiple"] or shards,
+min_bin_count=sequence_packing_args.get("min_bin_count") or shards,
+bin_count_multiple=sequence_packing_args.get("bin_count_multiple") or shards,
 )
🤖 Prompt for AI Agents
In nemo_rl/distributed/batched_data_dict.py around lines 414 to 416, the code
indexes optional keys on sequence_packing_args which can raise KeyError; change
these accesses to use .get(...) with a shards fallback (e.g.
sequence_packing_args.get("min_bin_count", shards) and
sequence_packing_args.get("bin_count_multiple", shards)), and also guard for
sequence_packing_args being None by using (sequence_packing_args or {}).get(...,
shards) so missing or absent dicts won't raise.
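A quick standalone demonstration of the guarded lookup described in the prompt (values are illustrative):

shards = 8
for args in (None, {}, {"min_bin_count": None}, {"min_bin_count": 16}):
    print((args or {}).get("min_bin_count") or shards)  # prints 8, 8, 8, 16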

@terrykong terrykong requested a review from a team September 17, 2025 22:16
@euronymous-aithal
Contributor

@yaoyu-33 @yfw can we get a review for this?

@guyueh1
Contributor Author

guyueh1 commented Nov 6, 2025

@parthmannan do you have a different PR that supports asymmetric VPP? If so, maybe we should close this one and work on your PR, since it covers more cases.

@parthmannan
Contributor

> @parthmannan do you have a different PR that supports asymmetric VPP? if so maybe we should close this and work on your PR as that covers more cases

It would be a small change on top of this. I haven't opened it yet because I got pulled into something urgent. I'll try to open it, with your base VPP changes included, by tomorrow.

@guyueh1
Contributor Author

guyueh1 commented Nov 9, 2025

@parthmannan let's merge all VPP-related changes in one PR; I can drive this if you merge your changes into my branch (guyueh1/support_mcore_vpp).

@guyueh1 guyueh1 requested review from a team as code owners November 9, 2025 22:25
@guyueh1 guyueh1 self-assigned this Nov 9, 2025
@guyueh1 guyueh1 added the CI:L0 Run doctests and unit tests label Nov 9, 2025
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
@guyueh1 guyueh1 requested a review from a team as a code owner November 10, 2025 17:41
@guyueh1 guyueh1 added CI:L0 Run doctests and unit tests and removed CI:L0 Run doctests and unit tests labels Nov 10, 2025
@guyueh1 guyueh1 requested a review from parthmannan November 10, 2025 17:41
@guyueh1 guyueh1 requested a review from a team as a code owner November 11, 2025 00:33
@guyueh1 guyueh1 added CI:L0 Run doctests and unit tests and removed CI:L0 Run doctests and unit tests labels Nov 17, 2025
@guyueh1 guyueh1 changed the title feat: Support mcore virtual pipeline parallel perf: Support mcore virtual pipeline parallel Dec 19, 2025

Labels

CI:L0 Run doctests and unit tests


Development

Successfully merging this pull request may close these issues.

Support VPP in MCORE path to reduce policy train pipeline bubbles

3 participants