Describe the bug
When performing SFT with data packing enabled, the get_batch_on_this_tp_rank function fails to broadcast cu_seqlens and max_seqlen to Tensor Parallel (TP) ranks in intermediate Pipeline Parallel (PP) stages (which exist whenever PP > 2).
As a result, cu_seqlens is None on TP ranks > 0 in any stage that is neither first nor last. The Transformer layers there (specifically RoPE and Flash Attention) fall back to the non-packed code path, leading to a dimension mismatch error:
RuntimeError: Tensors must have same number of dimensions: got 4 and 3.
Environment
Hardware: Multi-node (e.g., 2 nodes x 8 GPUs)
Megatron-LM Version: [main branch at 26-04-01]
Configuration: TP=2, PP=4 (or any PP > 2), SFT Packing enabled.
Model: DeepSeek-V3 architecture, constructed with 4 layers
Steps to Reproduce:
Enable SFT with data packing (providing cu_seqlens in the dataset).
Set pipeline-model-parallel-size to a value greater than 2 (e.g., PP=4).
Set tensor-model-parallel-size to a value greater than 1 (e.g., TP=2).
Start training. The process will crash during the forward pass of the 2nd or 3rd pipeline stage.
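The reproduction settings above correspond roughly to the following launch flags (script name, remaining flags, and data arguments are placeholders, not taken from the actual run):

```shell
# Hypothetical 2-node x 8-GPU launch matching TP=2, PP=4 and a 4-layer model.
torchrun --nnodes 2 --nproc_per_node 8 pretrain_gpt.py \
  --tensor-model-parallel-size 2 \
  --pipeline-model-parallel-size 4 \
  --num-layers 4
# plus SFT / data-packing flags and dataset arguments (omitted)
```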
Expected Behavior:
cu_seqlens and max_seqlen should be synchronized across all TP ranks in every pipeline stage, as every Transformer layer requires these tensors for packed sequence indexing and RoPE calculation.
Actual Behavior (Error Stack Trace):
```
File "megatron/core/models/common/embeddings/rope_utils.py", line 300, in apply_rotary_pos_emb
    return _apply_rotary_pos_emb_bshd(...)
File "megatron/core/models/common/embeddings/rope_utils.py", line 126, in _apply_rotary_pos_emb_bshd
    return torch.cat((t, t_pass), dim=-1)
RuntimeError: Tensors must have same number of dimensions: got 4 and 3
```
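The final frame can be reproduced in isolation: torch.cat requires all inputs to have the same rank, and in the failing RoPE path t carries one more dimension than t_pass because the packed metadata was never broadcast. The shapes below are illustrative only, not the actual activation sizes:

```python
import torch

# Illustrative shapes: t is 4-D while t_pass is 3-D, mirroring the mismatch
# hit inside _apply_rotary_pos_emb_bshd when cu_seqlens is missing.
t = torch.zeros(8, 1, 4, 16)
t_pass = torch.zeros(8, 1, 16)

try:
    torch.cat((t, t_pass), dim=-1)
    msg = None
except RuntimeError as e:
    msg = str(e)  # dimension-mismatch error, as in the trace above

print(msg)
```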
Suggested Fix
Modify get_batch_on_this_tp_rank in megatron/training/utils.py to ensure intermediate stages also perform the broadcast for metadata required by packed sequences.
```python
# In get_batch_on_this_tp_rank
if mpu.is_pipeline_first_stage():
    ...  # existing logic
elif mpu.is_pipeline_last_stage():
    ...  # existing logic
    _broadcast_cu_seqlens(batch['cu_seqlens'])  # missing in current version
    _broadcast(batch['max_seqlen'])             # missing in current version
else:
    # Intermediate (middle) stages
    _broadcast(batch['attention_mask'])
    _broadcast_cu_seqlens(batch['cu_seqlens'])  # critical fix
    _broadcast(batch['max_seqlen'])             # critical fix
```
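One reason cu_seqlens needs its own helper is that it is variable-length, so receiving ranks must learn the size before they can allocate a buffer. A minimal length-prefixed scheme is sketched below; this is an assumption about how such a helper could work, not Megatron-LM's actual implementation, and it is run here with a single-process gloo group purely for illustration:

```python
import os
import torch
import torch.distributed as dist

def broadcast_cu_seqlens(cu_seqlens, src=0, group=None):
    """Length-prefixed broadcast for a variable-length 1-D int tensor (sketch)."""
    rank = dist.get_rank(group)
    # 1) Broadcast the element count so non-source ranks can allocate a buffer.
    n = torch.tensor([cu_seqlens.numel() if rank == src else 0], dtype=torch.int64)
    dist.broadcast(n, src=src, group=group)
    if rank != src:
        cu_seqlens = torch.empty(int(n.item()), dtype=torch.int32)
    # 2) Broadcast the payload itself.
    dist.broadcast(cu_seqlens, src=src, group=group)
    return cu_seqlens

# Single-process demo: with world_size=1 the broadcast is a round-trip no-op.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)
out = broadcast_cu_seqlens(torch.tensor([0, 4, 9], dtype=torch.int32))
print(out.tolist())
dist.destroy_process_group()
```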
When I switch to v_core0.16.1, the code below in training/utils.py sets max_seqlen to None:
```python
if args.sft:
    max_seqlen = torch.empty(
        1,
        dtype=torch.int32,
        device=torch.cuda.current_device(),
    )
else:
    max_seqlen = None
    cu_seqlens = None
max_seqlen = torch.empty(
    1,
    dtype=torch.int32,
    device=torch.cuda.current_device(),
) if args.hybrid_context_parallel else None
```
and triggers an error in pretrain_gpt.py, because None has no dim():
```python
if cu_seqlens is not None:
    assert (
        cu_seqlens.dim() == 2 and cu_seqlens.shape[0] == 1
    ), "micro-batch-size must be 1 for packing"
    cu_seqlens = cu_seqlens[0]
    assert max_seqlen.dim() == 1
```
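In the v_core0.16.1 path quoted above, the second assignment unconditionally overwrites the SFT allocation whenever args.hybrid_context_parallel is false. One way to avoid the clobber is to fold both conditions into a single allocation. The sketch below uses a mocked args object and CPU tensors so it runs anywhere; the real code allocates on the current CUDA device:

```python
from types import SimpleNamespace
import torch

def alloc_max_seqlen(args, device="cpu"):
    # Allocate the broadcast buffer whenever any consumer (SFT packing or
    # hybrid context parallelism) needs it, so no later unconditional
    # assignment can reset the SFT allocation back to None.
    if args.sft or args.hybrid_context_parallel:
        return torch.empty(1, dtype=torch.int32, device=device)
    return None

sft_args = SimpleNamespace(sft=True, hybrid_context_parallel=False)
plain_args = SimpleNamespace(sft=False, hybrid_context_parallel=False)
print(alloc_max_seqlen(sft_args) is not None)   # True: SFT keeps its buffer
print(alloc_max_seqlen(plain_args) is None)     # True: plain pretraining
```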