2 changes: 2 additions & 0 deletions examples/configs/distillation_math.yaml
@@ -97,6 +97,8 @@ policy: &POLICY_BASE
pipeline_model_parallel_size: 2
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
+virtual_pipeline_model_parallel_size: null
+pipeline_model_parallel_layout: null
context_parallel_size: 2
pipeline_dtype: ${policy.precision}
sequence_parallel: false
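For orientation, the two new keys opt into Megatron's interleaved (virtual) pipeline schedule. Below is a minimal sketch of the divisibility constraint a uniform interleaved layout implies; the function and the exact upstream validation are illustrative assumptions, not code from this PR:

```python
def validate_vpp_settings(
    num_layers: int,
    pipeline_model_parallel_size: int,
    virtual_pipeline_model_parallel_size: int | None,
) -> None:
    """Sanity-check a uniform interleaved layout (illustrative only)."""
    if virtual_pipeline_model_parallel_size is None:
        return  # null in the YAML keeps the plain, non-interleaved schedule
    num_chunks = pipeline_model_parallel_size * virtual_pipeline_model_parallel_size
    # Without an explicit pipeline_model_parallel_layout, layers are split
    # evenly across pp_size * vpp_size model chunks, so they must divide num_layers.
    assert num_layers % num_chunks == 0, (
        f"num_layers={num_layers} not divisible by pp*vpp={num_chunks}; "
        "provide pipeline_model_parallel_layout for an uneven split"
    )


# 24 layers over pp=2, vpp=3: 6 chunks of 4 layers each.
validate_vpp_settings(24, pipeline_model_parallel_size=2,
                      virtual_pipeline_model_parallel_size=3)
```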
2 changes: 2 additions & 0 deletions examples/configs/dpo.yaml
@@ -128,6 +128,8 @@ policy:
expert_tensor_parallel_size: 1
expert_model_parallel_size: 1
pipeline_model_parallel_size: 1
+virtual_pipeline_model_parallel_size: null
+pipeline_model_parallel_layout: null
context_parallel_size: 1
pipeline_dtype: ${policy.precision}
num_layers_in_first_pipeline_stage: null
2 changes: 2 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -138,6 +138,8 @@ policy:
pipeline_model_parallel_size: 1
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
+virtual_pipeline_model_parallel_size: null
+pipeline_model_parallel_layout: null
context_parallel_size: 1
pipeline_dtype: ${policy.precision}
sequence_parallel: false
2 changes: 2 additions & 0 deletions examples/configs/rm.yaml
@@ -87,6 +87,8 @@ policy:
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
sequence_parallel: false
+virtual_pipeline_model_parallel_size: null
+pipeline_model_parallel_layout: null
gradient_accumulation_fusion: false

optimizer:
2 changes: 2 additions & 0 deletions examples/configs/sft.yaml
@@ -113,6 +113,8 @@ policy:
pipeline_dtype: ${policy.precision}
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
+virtual_pipeline_model_parallel_size: null
+pipeline_model_parallel_layout: null
sequence_parallel: false
freeze_moe_router: false
moe_router_dtype: null
2 changes: 2 additions & 0 deletions examples/nemo_gym/grpo_nanov3.yaml
@@ -88,6 +88,8 @@ policy:
pipeline_model_parallel_size: 2
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
+virtual_pipeline_model_parallel_size: null
+pipeline_model_parallel_layout: null
context_parallel_size: 4
pipeline_dtype: ${policy.precision}
sequence_parallel: true
@@ -107,6 +107,8 @@ policy:
pipeline_model_parallel_size: 1
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
+virtual_pipeline_model_parallel_size: null
+pipeline_model_parallel_layout: null
context_parallel_size: 1
pipeline_dtype: ${policy.precision}
sequence_parallel: false
18 changes: 15 additions & 3 deletions nemo_rl/distributed/batched_data_dict.py
@@ -27,7 +27,7 @@
)

import torch
-from typing_extensions import Self
+from typing_extensions import NotRequired, Self

from nemo_rl.data.multimodal_utils import (
PackedTensor,
@@ -54,6 +54,16 @@ class SequencePackingArgs(TypedDict):
sequence_length_pad_multiple: (
int # pad each sequence to a multiple of this value (for CP/TP alignment)
)
+min_bin_count: (
+NotRequired[
+int
+] # Minimum number of bins to create, even if fewer would suffice
+)
+bin_count_multiple: (
+NotRequired[
+int
+] # If specified, the total number of bins will be divisible by this value
+)


class DynamicBatchingArgs(TypedDict):
@@ -431,8 +441,10 @@ def shard_by_batch_size(
algorithm=sequence_packing_args["algorithm"],
bin_capacity=sequence_packing_args["max_tokens_per_microbatch"],
collect_metrics=False, # TODO(ahmadki): make configurable
-min_bin_count=shards,
-bin_count_multiple=shards,
+min_bin_count=(sequence_packing_args.get("min_bin_count") or shards),
+bin_count_multiple=(
+sequence_packing_args.get("bin_count_multiple") or shards
+),
)

input_lengths_key = sequence_packing_args["input_lengths_key"]
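The two new knobs let callers force extra bins beyond what capacity alone requires; `shard_by_batch_size` above defaults both to `shards` so every data-parallel shard receives the same number of packed microbatches. Here is a toy first-fit-decreasing packer showing the intended semantics; the real packer lives elsewhere in the repo and would rebalance sequences rather than emit empty bins:

```python
import math


def pack_with_bin_constraints(
    seqlens: list[int],
    bin_capacity: int,
    min_bin_count: int = 1,
    bin_count_multiple: int = 1,
) -> list[list[int]]:
    """First-fit-decreasing packing honoring the two new knobs (a sketch)."""
    bins: list[list[int]] = []
    loads: list[int] = []
    for length in sorted(seqlens, reverse=True):
        for i in range(len(bins)):
            if loads[i] + length <= bin_capacity:  # first bin with room
                bins[i].append(length)
                loads[i] += length
                break
        else:
            bins.append([length])  # nothing fits: open a new bin
            loads.append(length)
    # Pad the bin count upward until both constraints hold.
    target = max(len(bins), min_bin_count)
    target = math.ceil(target / bin_count_multiple) * bin_count_multiple
    bins.extend([] for _ in range(target - len(bins)))
    return bins


# Capacity packing alone yields 3 bins; the constraints round up to 4.
print(pack_with_bin_constraints([900, 700, 300, 200], bin_capacity=1024,
                                min_bin_count=4, bin_count_multiple=4))
```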
6 changes: 5 additions & 1 deletion nemo_rl/models/megatron/config.py
@@ -70,7 +70,11 @@ class ModelAndOptimizerState(NamedTuple):
"""

state: GlobalState
-model: MegatronModule
+model: (
+list[
+MegatronModule
+] # each element is a model chunk corresponding to a virtual pipeline stage
+)
optimizer: MegatronOptimizer
scheduler: OptimizerParamScheduler
checkpointing_context: dict[str, Any]
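Since `model` is now always a list (length 1 when VPP is off), call sites iterate instead of unwrapping. A two-line illustration of the pattern the rest of this PR adopts; the function name is hypothetical:

```python
def set_eval_mode(model_chunks):
    # One entry per virtual pipeline stage; a single entry when VPP is disabled.
    for chunk in model_chunks:
        chunk.eval()
```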
37 changes: 23 additions & 14 deletions nemo_rl/models/megatron/setup.py
@@ -384,6 +384,15 @@ def _apply_parallelism_config(model_cfg: Any, config: PolicyConfig) -> None:
]
model_cfg.sequence_parallel = config["megatron_cfg"]["sequence_parallel"]
model_cfg.context_parallel_size = config["megatron_cfg"]["context_parallel_size"]
+model_cfg.virtual_pipeline_model_parallel_size = config["megatron_cfg"][
+"virtual_pipeline_model_parallel_size"
+]
+model_cfg.pipeline_model_parallel_layout = config["megatron_cfg"][
+"pipeline_model_parallel_layout"
+]
+model_cfg.microbatch_group_size_per_vp_stage = config["megatron_cfg"][
+"pipeline_model_parallel_size"
+]

if model_cfg.context_parallel_size > 1:
assert config["sequence_packing"]["enabled"], (
@@ -936,9 +945,6 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
if len(model) == 1:
param_sync_func = param_sync_func[0]

-# Get the first model from the list
-model = model[0]

return ModelAndOptimizerState(
state,
model,
@@ -1088,16 +1094,19 @@
reference_state_dict = {}

if should_load_checkpoint or use_peft:
-reference_model = reference_model[0]
-reference_model.eval()
-# Store reference state dict on CPU
-for name, item in reference_model.state_dict().items():
-if isinstance(item, torch.Tensor):
-cpu_item = item.detach().to(device="cpu", non_blocking=True, copy=True)
-del item
-else:
-cpu_item = item
-reference_state_dict[name] = cpu_item
+for chunk in reference_model:
+chunk.eval()
+chunk_state_dict = {}
+for name, item in chunk.state_dict().items():
+if isinstance(item, torch.Tensor):
+cpu_item = item.detach().to(
+device="cpu", non_blocking=True, copy=True
+)
+del item
+else:
+cpu_item = item
+chunk_state_dict[name] = cpu_item
+reference_state_dict.update(chunk_state_dict)
print("Reference model loaded")
else:
print("Reference model not loaded")
@@ -1119,7 +1128,7 @@ def finalize_megatron_setup(
Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size)
"""
_update_model_config_funcs(
-[model],
+model,
megatron_cfg.model,
megatron_cfg.ddp,
optimizer,
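For intuition about what `virtual_pipeline_model_parallel_size` buys: each pipeline rank holds several non-contiguous blocks of layers, which shrinks the pipeline bubble. A sketch of the standard uniform interleaved assignment, assumed here; Megatron's actual bookkeeping, and any explicit `pipeline_model_parallel_layout`, can differ:

```python
def interleaved_layer_ranges(
    num_layers: int, pp_size: int, vpp_size: int
) -> dict[tuple[int, int], list[int]]:
    """Map (pp_rank, vpp_stage) -> owned layer indices under a uniform split."""
    per_chunk = num_layers // (pp_size * vpp_size)
    layout = {}
    for vpp_stage in range(vpp_size):
        for pp_rank in range(pp_size):
            start = vpp_stage * (num_layers // vpp_size) + pp_rank * per_chunk
            layout[(pp_rank, vpp_stage)] = list(range(start, start + per_chunk))
    return layout


# 8 layers, pp=2, vpp=2: rank 0 owns layers [0, 1] and [4, 5],
# rank 1 owns [2, 3] and [6, 7], i.e. two model chunks per rank.
print(interleaved_layer_ranges(8, pp_size=2, vpp_size=2))
```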
10 changes: 10 additions & 0 deletions nemo_rl/models/megatron/train.py
@@ -15,6 +15,7 @@
from collections import defaultdict
from contextlib import nullcontext
from functools import partial
+from itertools import tee
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch
@@ -311,6 +312,15 @@ def megatron_forward_backward(
use_linear_ce_fusion_loss=use_linear_ce_fusion_loss,
)
forward_backward_func = get_forward_backward_func()
+# The interleaved pipeline schedule (VPP) requires data_iterator to be a list with
+# one independent iterator per model chunk. Each chunk processes ALL num_microbatches
+# in sequence, so every iterator must be able to yield num_microbatches items
+# independently. tee() buffers the underlying stream once and hands each chunk its
+# own iterator over the same data. The non-interleaved schedule accepts a length-1
+# list and unwraps it itself.
+num_model_chunks = len(model) if isinstance(model, list) else 1
+if num_model_chunks > 1:
+data_iterator = list(tee(data_iterator, num_model_chunks))
return forward_backward_func(
forward_step_func=forward_step,
data_iterator=data_iterator,
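The `tee` call above is doing real work, so here is a standalone demonstration of the property the comment relies on: each tee'd iterator replays the full stream independently.

```python
from itertools import tee

microbatches = iter([{"idx": i} for i in range(4)])

# Two independent iterators over one underlying stream; tee buffers items
# until every copy has consumed them, so each chunk sees all microbatches.
chunk0_it, chunk1_it = tee(microbatches, 2)

assert [mb["idx"] for mb in chunk0_it] == [0, 1, 2, 3]
assert [mb["idx"] for mb in chunk1_it] == [0, 1, 2, 3]

# Caveat: after tee(), the original `microbatches` iterator must not be
# advanced directly, or the copies silently miss those items.
```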
2 changes: 2 additions & 0 deletions nemo_rl/models/policy/__init__.py
@@ -207,6 +207,8 @@ class MegatronConfig(TypedDict):
pipeline_model_parallel_size: int
num_layers_in_first_pipeline_stage: int | None
num_layers_in_last_pipeline_stage: int | None
+virtual_pipeline_model_parallel_size: int | None
+pipeline_model_parallel_layout: Union[str, list] | None
context_parallel_size: int
pipeline_dtype: str
sequence_parallel: bool
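A fragment showing how the new fields might be populated; the values are illustrative, and the layout grammar is Megatron's, not specified by this PR:

```python
megatron_cfg = {
    "pipeline_model_parallel_size": 2,
    # Two model chunks per pipeline rank; None keeps the plain schedule.
    "virtual_pipeline_model_parallel_size": 2,
    # An explicit per-stage layout (str or list) overrides the uniform split;
    # leave None unless the model needs an uneven layer assignment.
    "pipeline_model_parallel_layout": None,
}
```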
16 changes: 16 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
@@ -298,6 +298,22 @@ def __init__(
"input_lengths_key": "input_lengths",
"sequence_length_pad_multiple": sequence_length_pad_multiple,
}
+# When virtual pipeline parallelism is enabled, each rank's number of microbatches
+# must be divisible by pp_size, so pass min_bin_count and bin_count_multiple values
+# that keep the global bin count divisible by dp_size * pp_size.
+dp_size = self.sharding_annotations.get_axis_size("data_parallel")
+vpp_size = (
+config["megatron_cfg"]["virtual_pipeline_model_parallel_size"] or 1
+)
+vpp_layout = config["megatron_cfg"]["pipeline_model_parallel_layout"]
+make_num_microbatch_divisible_by = None
+if vpp_size > 1 or vpp_layout is not None:
+make_num_microbatch_divisible_by = dp_size * pp_size
+self.sequence_packing_args["min_bin_count"] = (
+make_num_microbatch_divisible_by
+)
+self.sequence_packing_args["bin_count_multiple"] = (
+make_num_microbatch_divisible_by
+)
assert not config["dynamic_batching"]["enabled"], (
"Sequence Packing is exclusive of Dynamic Batching. Please disable Dynamic Batching"
)
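Why `dp_size * pp_size` is the right multiple: bins are created globally, then sharded across data-parallel ranks, and the interleaved schedule needs each rank's microbatch count divisible by `pp_size`. A worked check under those assumptions, not library code:

```python
dp_size, pp_size = 4, 2
make_num_microbatch_divisible_by = dp_size * pp_size  # 8

# The packer pads the global bin count up to a multiple of 8 (see the
# min_bin_count / bin_count_multiple knobs above), e.g. 24 bins.
global_bins = 24
per_rank_microbatches = global_bins // dp_size  # 6

assert global_bins % make_num_microbatch_divisible_by == 0
assert per_rank_microbatches % pp_size == 0  # interleaved-schedule requirement
```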