perf: Support mcore virtual pipeline parallel #1126
base: main
Conversation
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
📝 Walkthrough

Updates include a submodule pointer bump, introduction of two optional sequence packing parameters, pipeline-parallel-aware binning in LM policy training, a broad refactor of MegatronPolicyWorker to handle a list of models (multi-model/pipeline/virtual-pipeline aware), and a corresponding utility change to read modules from the first model in the list.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Trainer
    participant LMPolicy
    participant ParallelConfig as Parallel Config
    participant SeqPack as Sequence Packing Args
    participant BinPacker
    Trainer->>LMPolicy: start_training()
    LMPolicy->>ParallelConfig: get dp_size, pp_size
    ParallelConfig-->>LMPolicy: dp_size, pp_size
    LMPolicy->>SeqPack: set min_bin_count = dp_size*pp_size\nset bin_count_multiple = dp_size*pp_size
    LMPolicy->>BinPacker: construct(..., min_bin_count, bin_count_multiple)
    BinPacker-->>LMPolicy: ready
    LMPolicy-->>Trainer: proceed with training
```
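The binning setup in the first diagram boils down to a small amount of arithmetic. A minimal sketch follows; the function and variable names (compute_bin_constraints, dp_size, pp_size) are illustrative assumptions, not code from this PR:

```python
# Sketch only (hypothetical names): derive sequence-packing bin constraints from the
# data-parallel and pipeline-parallel sizes so that every (dp rank, pp stage) pair can
# be assigned the same number of packed bins.
def compute_bin_constraints(dp_size: int, pp_size: int) -> dict[str, int]:
    group = dp_size * pp_size
    return {
        "min_bin_count": group,       # never produce fewer bins than dp*pp
        "bin_count_multiple": group,  # keep the total bin count divisible by dp*pp
    }

print(compute_bin_constraints(dp_size=4, pp_size=2))
# {'min_bin_count': 8, 'bin_count_multiple': 8}
```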
```mermaid
sequenceDiagram
    autonumber
    actor Orchestrator
    participant Worker as MegatronPolicyWorker
    participant Models as [Model_0, Model_1, ...]
    participant DataIt as Data Iterator(s)
    participant Checkpoint as Checkpoint/Export
    Orchestrator->>Worker: initialize(models=list)
    Worker->>Models: for each: move to device, set hooks
    alt virtual_pipeline_model_parallel_size > 1
        Worker->>DataIt: create list of iterators (per VP stage)
    else
        Worker->>DataIt: create single iterator
    end
    loop training steps
        Worker->>Models: for each: zero grads, reset caches
        Worker->>Models: forward/backward (passes list to step fn)
    end
    Worker->>Checkpoint: save(model=list), export weights
    Worker-->>Orchestrator: done
```
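The alt branch in the second diagram (one iterator per virtual-pipeline stage versus a single iterator) can be illustrated with a short sketch; this uses plain itertools and hypothetical names rather than the worker's actual helpers:

```python
# Sketch only (hypothetical names): with virtual pipeline parallelism, the pipeline
# schedule expects one data iterator per local model chunk, so the worker builds a
# list of iterators; without VP a single iterator is enough.
import itertools
from typing import Iterable, Iterator, Optional, Union

def make_data_iterators(
    batches: Iterable, vp_size: Optional[int]
) -> Union[Iterator, list]:
    if vp_size is not None and vp_size > 1:
        # one independent iterator per virtual-pipeline stage (model chunk)
        return list(itertools.tee(batches, vp_size))
    return iter(batches)

iters = make_data_iterators(range(8), vp_size=2)
print([next(it) for it in iters])  # [0, 0] -- each chunk reads the same microbatch stream
```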
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
nemo_rl/models/policy/utils.py (1)
15-18: get_gpu_info should accept single model or list of models.

Using model[0] assumes indexability and breaks dtensor paths that pass a single nn.Module. Make it robust and update typing.

```diff
-from typing import Any, Dict
+from typing import Any, Dict, Sequence

-def get_gpu_info(model: torch.nn.Module) -> dict[str, Any]:
+def get_gpu_info(model: torch.nn.Module | Sequence[torch.nn.Module]) -> dict[str, Any]:

-    for module_name, module in model[0].named_modules():
+    first_model = model[0] if isinstance(model, (list, tuple)) else model
+    for module_name, module in first_model.named_modules():
```

Also applies to: 103-104, 134-146
nemo_rl/models/policy/megatron_policy_worker.py (2)
1312-1376: use_reference_model is broken with list-of-models (loads single state_dict into list).

self.model is now a list; this context manager calls self.model.state_dict(), which will crash, and reference_state_dict holds only the rank-0 chunk. Load/save must iterate all sub-models.
Apply:
```diff
-        try:
-            # Save original references
-            model_state_dict = {}
-            for name, item in self.model.state_dict().items():
-                if isinstance(item, torch.Tensor):
-                    item = item.detach().to(
-                        device="cpu", non_blocking=True, copy=True
-                    )
-                model_state_dict[name] = item
-
-            self.model.load_state_dict(self.reference_state_dict, strict=True)
+        try:
+            # Save original references for all sub-models
+            model_state_dicts = []
+            for m in self.model:
+                sd = {}
+                for name, item in m.state_dict().items():
+                    if isinstance(item, torch.Tensor):
+                        item = item.detach().to(device="cpu", non_blocking=True, copy=True)
+                    sd[name] = item
+                model_state_dicts.append(sd)
+
+            # Load per-submodel reference weights
+            assert hasattr(self, "reference_state_dicts"), "reference_state_dicts missing"
+            for m, ref_sd in zip(self.model, self.reference_state_dicts):
+                m.load_state_dict(ref_sd, strict=True)
@@
-            # Restore original references and device placement
-            self.model.load_state_dict(model_state_dict, strict=True)
+            # Restore original references per sub-model
+            for m, sd in zip(self.model, model_state_dicts):
+                m.load_state_dict(sd, strict=True)
```

And in the reference load path, persist all chunks instead of the first:

```diff
-        reference_model = reference_model[0]
-        reference_model.eval()
-        self.reference_state_dict = {}
-        for name, item in reference_model.state_dict().items():
-            ...
-        print("Reference model loaded")
+        for rm in reference_model:
+            rm.eval()
+        self.reference_state_dicts = []
+        for rm in reference_model:
+            ref_sd = {}
+            for name, item in rm.state_dict().items():
+                cpu_item = item.detach().to(device="cpu", non_blocking=True, copy=True) if isinstance(item, torch.Tensor) else item
+                ref_sd[name] = cpu_item
+            self.reference_state_dicts.append(ref_sd)
+        print("Reference model loaded")
```
1811-1860: Teach move_model to handle lists of sub-models.

Centralize list handling to prevent crashes at multiple call sites (prepare_for_lp_inference, offload_before_refit, offload_after_refit).
```diff
-    def move_model(
-        self,
-        model: torch.nn.Module,
+    def move_model(
+        self,
+        model: torch.nn.Module | list[torch.nn.Module],
         device: str,
         move_params: bool = True,
         move_grads: bool = True,
     ) -> torch.nn.Module:
-        # move all param and grad buffers to the device
+        # handle lists
+        if isinstance(model, list):
+            return [self.move_model(m, device, move_params, move_grads) for m in model]
+
+        # move all param and grad buffers to the device
         if isinstance(model, DistributedDataParallel):
             ...
         elif isinstance(model, custom_FSDP):
             ...
         else:  # Ordinary offload case
             if move_params:
                 new_state_dict = {}
                 for name, item in model.state_dict().items():
                     if isinstance(item, torch.Tensor):
                         item = item.detach().to(
                             device=device, non_blocking=True, copy=True
                         )
                     new_state_dict[name] = item
                 model.load_state_dict(new_state_dict)
         return model
```
🧹 Nitpick comments (3)
3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1)
1-1: Remove moving branch from .gitmodules for Megatron-Bridge.

Found .gitmodules: submodule "3rdparty/Megatron-Bridge" contains branch = yifu/nemo-rl-use-chunkpatch-ds; CI fetches submodules recursively. Remove that branch entry so CI pins to the recorded commit, and add an upstream diff (old..new) link in the PR description for auditability.

nemo_rl/distributed/batched_data_dict.py (1)
282-295: Update docs to mention the two new (optional) packing knobs.

Docstring still says “requires five keys”; please add min_bin_count and bin_count_multiple and call them optional with default fallback to shards.
nemo_rl/models/policy/megatron_policy_worker.py (1)
877-881: Reduce noisy prints in hot path.

Prints per-GPU per-GB will spam logs. Gate behind a debug flag or env guard.
- print(f"rank {get_rank_safe()}: total_dataset_size: {total_dataset_size.item()}") + if os.environ.get("NRL_DEBUG_TRAIN", "0") == "1": + print(f"rank {get_rank_safe()}: total_dataset_size: {total_dataset_size.item()}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1 hunks)
- nemo_rl/distributed/batched_data_dict.py (2 hunks)
- nemo_rl/models/policy/lm_policy.py (2 hunks)
- nemo_rl/models/policy/megatron_policy_worker.py (18 hunks)
- nemo_rl/models/policy/utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
nemo_rl/models/policy/lm_policy.py (1)
nemo_rl/distributed/named_sharding.py (1)
get_axis_size(209-211)
nemo_rl/models/policy/megatron_policy_worker.py (2)
nemo_rl/models/policy/dtensor_policy_worker.py (2)
train (522-888)
save_checkpoint (1487-1505)
nemo_rl/distributed/batched_data_dict.py (1)
make_microbatch_iterator_for_packable_sequences(763-777)
🔇 Additional comments (8)
3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1)
1-1: Submodule bump — new SHA present upstream; local diff unavailable; verify upstream changes and branch.
- Old → New: abd52c89fe969869b8969acc181630c273cca4fd → 30f24c667c3e909f861dd583ca4896552a70d9e9 (remote: https://github.com/NVIDIA-NeMo/Megatron-Bridge.git).
- Submodule not initialized in this environment; run locally:
cd 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge && git fetch --all && git diff --stat abd52c89..30f24c66 or git log --oneline abd52c89..30f24c66.
- .gitmodules branch: yifu/nemo-rl-use-chunkpatch-ds — confirm this non-default branch is intentional and that the commit is reachable and CI-passing.
- Review upstream diff for regressions and confirm provenance before merging.
nemo_rl/models/policy/lm_policy.py (1)
359-361: Confirm binning multiplicity with virtual pipeline.

You set min_bin_count/bin_count_multiple = dp_size * pp_size. With virtual pipeline parallel (vp_size > 1), do we also need to factor vp_size so microbatches align per virtual stage? Please confirm against MCore forward_backward expectations.
If needed, the adjustment would be:
```diff
-        self.sequence_packing_args["min_bin_count"] = dp_size * pp_size
-        self.sequence_packing_args["bin_count_multiple"] = dp_size * pp_size
+        vp_size = self.cfg["megatron_cfg"].get("virtual_pipeline_model_parallel_size", 1)
+        self.sequence_packing_args["min_bin_count"] = dp_size * pp_size * vp_size
+        self.sequence_packing_args["bin_count_multiple"] = dp_size * pp_size * vp_size
```

nemo_rl/models/policy/megatron_policy_worker.py (6)
528-534: Suspicious assignment: microbatch_group_size_per_vp_stage.

Setting microbatch_group_size_per_vp_stage = pipeline_model_parallel_size looks off; this knob typically relates to virtual pipeline grouping. Please verify against MCore config; it may need virtual_pipeline_model_parallel_size or a derived grouping value instead.
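For context, a hedged sketch of the relationship being questioned; the numbers and the divisibility check are assumptions used to illustrate the concern, not constraints copied from Megatron-Core, so verify them against the pinned MCore version:

```python
# Illustrative values only (assumed, not from this PR).
pp_size = 4            # pipeline_model_parallel_size
vp_size = 2            # virtual_pipeline_model_parallel_size
num_microbatches = 16  # microbatches per data-parallel rank

# What the diff appears to set: group contiguous microbatches per virtual-pipeline
# stage by the pipeline size.
microbatch_group_size_per_vp_stage = pp_size

# The kind of sanity check the reviewer is asking to confirm: the grouping should
# evenly divide the microbatch count so every model chunk processes whole groups.
assert num_microbatches % microbatch_group_size_per_vp_stage == 0
```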
554-557: Good: enable overlap_p2p_comm when PP and VP are used.

This should reduce bubbles; disabling batch_p2p_comm here is consistent.
695-756: List-of-model handling during reference-model setup is on the right track.

Moving each sub-model to CPU before instantiating the reference and then restoring to CUDA is correct for multi-chunk models.
952-960: VP-aware iterator list looks correct.

Providing a list of iterators per VP stage when vp_size > 1 aligns with virtual pipeline semantics.
979-1000: Verify forward_backward accepts list iterators for VP.

There’s a TODO; ensure get_forward_backward_func supports a List[Iterator] for data_iterator with VP, and that num_microbatches matches the per-stage iterator lengths.
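To make the verification concrete, here is a hedged sketch of the call shape in question. It assumes Megatron-Core's public get_forward_backward_func API and an already-initialized model-parallel state; the variable names are placeholders, not the worker's actual code:

```python
# Sketch (assumptions noted above): with virtual pipeline parallelism, data_iterator
# and model are parallel lists with one entry per local model chunk.
from megatron.core.pipeline_parallel import get_forward_backward_func

def run_forward_backward(forward_step, data_iterators, model_chunks,
                         num_microbatches, seq_length, micro_batch_size):
    forward_backward_func = get_forward_backward_func()
    return forward_backward_func(
        forward_step_func=forward_step,
        data_iterator=data_iterators,   # List[Iterator] when vp_size > 1, else one iterator
        model=model_chunks,             # list of model chunks when vp_size > 1
        num_microbatches=num_microbatches,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        forward_only=False,
    )
```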
1721-1724: Handle list-of-modules when calling move_model.

prepare_for_lp_inference passes self.model (a list) into move_model, which expects a single module and will fail at runtime. Map over the list here or make move_model accept iterables.
Minimal local fix:
```diff
-        self.model = self.move_model(self.model, "cuda", move_grads=False)
+        self.model = [self.move_model(m, "cuda", move_grads=False) for m in self.model]
```

Prefer generalizing move_model() so all call sites accept lists.
In nemo_rl/distributed/batched_data_dict.py:

```python
    min_bin_count: Optional[int]
    bin_count_multiple: Optional[int]
```
🛠️ Refactor suggestion
Make optional keys truly optional in SequencePackingArgs (typing fix).
As written, Optional[int] in a total=True TypedDict still requires the keys. Use NotRequired so callers can omit them.
Apply:
```diff
-from typing_extensions import Self
+from typing_extensions import Self, NotRequired

-    min_bin_count: Optional[int]
-    bin_count_multiple: Optional[int]
+    min_bin_count: NotRequired[int]
+    bin_count_multiple: NotRequired[int]
```
🤖 Prompt for AI Agents
In nemo_rl/distributed/batched_data_dict.py around lines 57 to 58, the TypedDict
declares min_bin_count: Optional[int] and bin_count_multiple: Optional[int]
which (with total=True) still require the keys; change those annotations to use
NotRequired[int] and import NotRequired (from typing if on Python 3.11+,
otherwise from typing_extensions) so callers can omit the keys, and remove
Optional for those fields.
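For reference, a self-contained illustration of the NotRequired pattern this suggestion relies on; the class and field names here are generic stand-ins, not the real SequencePackingArgs definition:

```python
# Minimal, runnable illustration: NotRequired makes individual keys optional even in a
# total=True TypedDict, whereas Optional[int] only allows the value to be None.
from typing_extensions import NotRequired, TypedDict

class PackingArgs(TypedDict):            # total=True by default
    algorithm: str                       # still a required key
    min_bin_count: NotRequired[int]      # callers may omit this key
    bin_count_multiple: NotRequired[int]

ok: PackingArgs = {"algorithm": "first_fit"}  # type-checks even with optional keys omitted
full: PackingArgs = {"algorithm": "first_fit", "min_bin_count": 8, "bin_count_multiple": 8}
print(ok, full)
```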
In nemo_rl/distributed/batched_data_dict.py (around lines 414-416):

```python
            min_bin_count=sequence_packing_args["min_bin_count"] or shards,
            bin_count_multiple=sequence_packing_args["bin_count_multiple"] or shards,
        )
```
Avoid KeyError when optional fields are absent.
Accessing optional keys via dict indexing will raise KeyError. Use .get(...) with fallback.
```diff
-            min_bin_count=sequence_packing_args["min_bin_count"] or shards,
-            bin_count_multiple=sequence_packing_args["bin_count_multiple"] or shards,
+            min_bin_count=sequence_packing_args.get("min_bin_count") or shards,
+            bin_count_multiple=sequence_packing_args.get("bin_count_multiple") or shards,
```
🤖 Prompt for AI Agents
In nemo_rl/distributed/batched_data_dict.py around lines 414 to 416, the code
indexes optional keys on sequence_packing_args which can raise KeyError; change
these accesses to use .get(...) with a shards fallback (e.g.
sequence_packing_args.get("min_bin_count", shards) and
sequence_packing_args.get("bin_count_multiple", shards)), and also guard for
sequence_packing_args being None by using (sequence_packing_args or {}).get(...,
shards) so missing or absent dicts won't raise.
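A tiny self-contained sketch of the guarded access pattern the prompt describes; the helper name and defaults are illustrative, not taken from the codebase:

```python
# Guarded access: tolerate both a missing key and a None dict, falling back to `shards`.
from typing import Optional, Tuple

def resolve_bin_counts(sequence_packing_args: Optional[dict], shards: int) -> Tuple[int, int]:
    args = sequence_packing_args or {}
    min_bin_count = args.get("min_bin_count") or shards
    bin_count_multiple = args.get("bin_count_multiple") or shards
    return min_bin_count, bin_count_multiple

print(resolve_bin_counts(None, shards=8))                   # (8, 8)
print(resolve_bin_counts({"min_bin_count": 16}, shards=8))  # (16, 8)
```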
@parthmannan do you have a different PR that supports asymmetric VPP? If so, maybe we should close this and work on your PR, as that covers more cases.
It would be a small change on top of this. I haven't opened it yet as I got pulled into something urgent. I'll try to get it open with your base vpp changes in as well by tomorrow.
@parthmannan let's merge all changes related to vpp in one PR; I can drive this if you just merge your changes to my branch (guyueh1/support_mcore_vpp).
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Support asymmetric VPP
What does this PR do ?
Support virtual pipeline parallel (vpp) in mcore
Issues
closes #1038
Usage
```python
# Add a code snippet demonstrating how to use this
```

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit