Skip to content
Closed

Part 1 #4083

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/multimodal/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
is_pipeline_last_stage,
)
from megatron.training import get_args, get_timers, get_tokenizer, pretrain
from megatron.training.utils import is_last_rank, get_batch_on_this_cp_rank
from megatron.core.utils import get_batch_on_this_cp_rank
from megatron.training.utils import is_last_rank


def get_batch(data_iterator, image_token_index, img_seq_len):
Expand Down
4 changes: 3 additions & 1 deletion examples/post_training/modelopt/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
)
from utils import get_hf_tokenizer
from model_provider import model_provider
from megatron.core.parallel_state import get_context_parallel_group


REMOVE_THINK_CHAT_TEMPLATE = (
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
Expand Down Expand Up @@ -435,7 +437,7 @@ def get_batch(data_iterator):
batch["hidden_states"] = feature_b["hidden_states"].transpose(0, 1)[:args.seq_length]

# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)
batch = get_batch_on_this_cp_rank(batch, is_hybrid_cp=False, cp_group=get_context_parallel_group())

return batch

Expand Down
5 changes: 4 additions & 1 deletion examples/post_training/modelopt/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
mtq_luts = None
warnings.warn("luts is not installed. LUTs quantization configs will not be available.")

from megatron.core.parallel_state import get_context_parallel_group
from megatron.core.utils import get_batch_on_this_cp_rank
from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
Expand Down Expand Up @@ -410,7 +411,9 @@ def _dataset_forward_loop_func(model):
batch_size=args.calib_batch_size,
)
for sample in tqdm(dataloader, disable=torch.distributed.get_rank()):
sample = get_batch_on_this_cp_rank(sample)
sample = get_batch_on_this_cp_rank(
sample, is_hybrid_cp=False, cp_group=get_context_parallel_group()
)
simple_generate(model, sample["input_ids"], osl=1, calibration_mode=True)

unwrapped_model = unwrap_model(model)[0]
Expand Down
5 changes: 4 additions & 1 deletion megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ def __init__(
self.rotary_scaling = rope_scaling
self.mtp_block_spec = mtp_block_spec
self.mtp_process = mtp_block_spec is not None and mtp_on_this_rank(
self.config, ignore_virtual=False, vp_stage=vp_stage
layout=self.config.pipeline_model_parallel_layout,
mtp_num_layers=self.config.mtp_num_layers,
ignore_virtual=False,
vp_stage=vp_stage,
)

if self.pre_process or self.mtp_process:
Expand Down
7 changes: 6 additions & 1 deletion megatron/core/models/mamba/mamba_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,12 @@ def __init__(
# to split the hybrid layer pattern into pipeline stages before parsing the pattern for
# the current pipeline stage. This could also enable MTP standalone (MTP in a pipeline
# stage separate from loss) to be supported in the hybrid model.
and mtp_on_this_rank(self.config, ignore_virtual=False, vp_stage=self.vp_stage)
and mtp_on_this_rank(
layout=self.config.pipeline_model_parallel_layout,
mtp_num_layers=self.config.mtp_num_layers,
ignore_virtual=False,
vp_stage=self.vp_stage,
)
)

# megatron core pipelining currently depends on model type
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/models/mimo/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _apply_context_parallel(
batch["attention_mask"] = attention_mask

if packed_seq_params is None or getattr(packed_seq_params, 'qkv_format', 'sbhd') == 'sbhd':
batch = get_batch_on_this_cp_rank(batch)
batch = get_batch_on_this_cp_rank(batch, is_hybrid_cp=False, cp_group=self.cfg.cp_group)
else:
assert _HAVE_TEX and is_te_min_version("1.10.0"), (
"Please update Transformer Engine to >= 1.10 "
Expand Down
7 changes: 5 additions & 2 deletions megatron/core/models/multimodal/llava_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,9 +726,12 @@ def _process_embedding_token_parallel(
batch["new_loss_mask"] = new_loss_mask
# Distribute sequence across CP ranks
if packed_seq_params is None or packed_seq_params.qkv_format == 'sbhd':
from megatron.training.utils import get_batch_on_this_cp_rank
from megatron.core.parallel_state import get_context_parallel_group
from megatron.core.utils import get_batch_on_this_cp_rank

batch = get_batch_on_this_cp_rank(batch)
batch = get_batch_on_this_cp_rank(
batch, is_hybrid_cp=False, cp_group=get_context_parallel_group()
)
else:
assert HAVE_TEX and is_te_min_version(
"1.10.0"
Expand Down
20 changes: 12 additions & 8 deletions megatron/core/transformer/multi_token_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
else:
TESpecProvider = None

from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout


def tie_word_embeddings_state_dict(
sharded_state_dict: ShardedStateDict,
Expand Down Expand Up @@ -468,13 +470,16 @@ def get_mtp_layer_spec_for_backend(


def mtp_on_this_rank(
config: TransformerConfig, ignore_virtual: Optional[bool] = True, vp_stage: Optional[int] = None
layout: PipelineParallelLayerLayout = None,
mtp_num_layers: Optional[int] = None,
ignore_virtual: Optional[bool] = True,
vp_stage: Optional[int] = None,
) -> bool:
"""
Check if there is MTP on the current rank.

Behavior:
- If a custom pipeline model parallel layout is provided in the config:
- If a custom pipeline model parallel layout is provided:
- If virtual pipeline parallelism is enabled (and `ignore_virtual` is False), checks
whether any MTP layers are present on this (pp_rank, vp_stage) pair.
- Otherwise, checks all virtual pipeline ranks of the current pipeline rank. Returns
Expand All @@ -484,25 +489,24 @@ def mtp_on_this_rank(
"""
mtp_on_this_rank = False
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
if config.pipeline_model_parallel_layout is not None:
if layout is not None:
        # with a custom PP layout, we support putting MTP layers on any pipeline stage
layout = config.pipeline_model_parallel_layout.layout
if (
not ignore_virtual
and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None
):
assert vp_stage is not None, "vp_stage must be passed if virtual pipeline is enabled"
num_layers_to_build = layout[pp_rank][vp_stage].count(LayerType.mtp)
num_layers_to_build = layout.layout[pp_rank][vp_stage].count(LayerType.mtp)
mtp_on_this_rank = num_layers_to_build > 0
else:
for vpp_rank in range(len(layout[pp_rank])):
num_layers_to_build = layout[pp_rank][vpp_rank].count(LayerType.mtp)
for vpp_rank in range(len(layout.layout[pp_rank])):
num_layers_to_build = layout.layout[pp_rank][vpp_rank].count(LayerType.mtp)
if num_layers_to_build > 0:
mtp_on_this_rank = True
break
else:
        # without a custom PP layout, we only support putting all MTP layers on the last pipeline stage
if config.mtp_num_layers is not None:
if mtp_num_layers is not None:
mtp_on_this_rank = parallel_state.is_pipeline_last_stage(
ignore_virtual=ignore_virtual, vp_stage=vp_stage
)
Expand Down
Loading