Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 52 additions & 14 deletions megatron/core/models/common/embeddings/rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def _apply_rotary_pos_emb_thd(
rotary_interleaved: bool = False,
multi_latent_attention: bool = False,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
cp_group: Optional[torch.distributed.ProcessGroup] = None,
offsets: Optional[Tensor] = None,
) -> Tensor:
"""A baseline implementation of applying RoPE for `thd` format.

Expand All @@ -206,7 +207,10 @@ def _apply_rotary_pos_emb_thd(
# Handle two different frequency tensor formats:
# 1. If freqs.size(0) == cu_seqlens[-1]: freqs contains all positions across all sequences
# -> Use offset-based mapping for exact positional correspondence
# 2. Otherwise: freqs contains only max sequence length positions
# 2. If offsets are provided: t is a local packed shard and freqs contains max sequence length
# positions.
# -> Use per-fragment offsets to recover the correct global positions.
# 3. Otherwise: freqs contains only max sequence length positions
# -> Use traditional mapping without offsets (map first :seqlen part)
if freqs.dim() >= 1 and freqs.size(0) == cu_seqlens[-1]:
# CASE 1: Exact mapping with offsets
Expand All @@ -222,6 +226,31 @@ def _apply_rotary_pos_emb_thd(

freqs_packed = torch.cat(freq_slices, dim=0)

return _apply_rotary_pos_emb_bshd(
t.unsqueeze(1),
freqs_packed,
rotary_interleaved=rotary_interleaved,
multi_latent_attention=multi_latent_attention,
mscale=mscale,
).squeeze(1)
elif offsets is not None:
# CASE 2: Local packed shards with per-fragment offsets.
sequence_splits = torch.split(t, seqlens)
if offsets.numel() != len(sequence_splits):
raise ValueError(
f"offsets must provide one entry per local sequence split, got {offsets.numel()} "
f"offsets for {len(sequence_splits)} splits."
)
freqs_packed = torch.cat(
[
_get_thd_freqs_on_this_cp_rank(
cp_rank, cp_size, x, freqs, int(seq_start_offset.item())
)
for x, seq_start_offset in zip(sequence_splits, offsets)
],
dim=0,
)
Comment thread
kaimo455 marked this conversation as resolved.

return _apply_rotary_pos_emb_bshd(
t.unsqueeze(1),
freqs_packed,
Expand All @@ -230,7 +259,7 @@ def _apply_rotary_pos_emb_thd(
mscale=mscale,
).squeeze(1)
else:
# CASE 2: Traditional mapping without offsets
# CASE 3: Traditional mapping without offsets
# Build packed freqs for all sequences using the standard mapping, then apply once
sequence_splits = torch.split(t, seqlens)
freqs_packed = torch.cat(
Expand All @@ -253,8 +282,9 @@ def apply_rotary_pos_emb(
config: TransformerConfig,
cu_seqlens: Optional[Tensor] = None,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
):
cp_group: Optional[torch.distributed.ProcessGroup] = None,
offsets: Optional[Tensor] = None,
) -> Tensor:
"""
Reroute to the appropriate apply_rotary_pos_emb function depending on
fused/unfused kernels, or bshd (conventional) / thd (packed seq) format
Expand Down Expand Up @@ -286,15 +316,22 @@ def apply_rotary_pos_emb(
assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb(t, freqs, interleaved=config.rotary_interleaved)
else:
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb_thd(
t,
cu_seqlens,
freqs,
cp_size=cp_group.size(),
cp_rank=cp_group.rank(),
interleaved=config.rotary_interleaved,
)
if offsets is not None:
warnings.warn(
"offsets are not supported by fused THD RoPE. Using unfused implementation."
)
else:
assert (
fused_apply_rotary_pos_emb_thd is not None
), "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb_thd(
t,
cu_seqlens,
freqs,
cp_size=cp_group.size(),
cp_rank=cp_group.rank(),
interleaved=config.rotary_interleaved,
)
# use unfused implementation
if cu_seqlens is None:
return _apply_rotary_pos_emb_bshd(
Expand All @@ -313,6 +350,7 @@ def apply_rotary_pos_emb(
multi_latent_attention=config.multi_latent_attention,
mscale=mscale,
cp_group=cp_group,
offsets=offsets,
)


Expand Down
109 changes: 108 additions & 1 deletion tests/unit_tests/transformer/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import pytest
import torch

from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.core.models.common.embeddings import apply_rotary_pos_emb, rope_utils
from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd
from megatron.core.models.common.embeddings.rotary_pos_embedding import (
MultimodalRotaryEmbedding,
RotaryEmbedding,
Expand Down Expand Up @@ -150,3 +151,109 @@ def test_gpu_forward(self):
), f"Output sizes do not match for K: {k_out.shape} != {k_out_ref.shape}"
assert torch.allclose(q_out_ref, q_out), f"Outputs do not match for Q"
assert torch.allclose(k_out_ref, k_out), f"Outputs do not match for K"


class _FakeCpGroup:
def __init__(self, size: int, rank: int) -> None:
self._size = size
self._rank = rank

def size(self) -> int:
return self._size

def rank(self) -> int:
return self._rank


def test_apply_rotary_pos_emb_thd_uses_offsets_for_local_packed_shards():
    """Per-sequence offsets must map packed shards onto their global freq positions."""
    config = TransformerConfig(
        num_layers=1, hidden_size=4, num_attention_heads=1, apply_rope_fusion=False
    )
    cp_group = _FakeCpGroup(size=1, rank=0)

    # Two packed sequences (lengths 2 and 1) in thd layout.
    packed = torch.tensor(
        [[[1.0, 2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0, 9.0]]]
    )
    freqs = (torch.arange(16, dtype=torch.float32).view(4, 1, 1, 4) + 1.0) / 10.0
    cu_seqlens = torch.tensor([0, 2, 3], dtype=torch.int32)
    offsets = torch.tensor([2, 1], dtype=torch.int32)

    output = apply_rotary_pos_emb(
        packed, freqs, config, cu_seqlens=cu_seqlens, cp_group=cp_group, offsets=offsets
    )

    # With offsets, sequence 0 should use freqs[2:4] and sequence 1 freqs[1:2].
    shifted_freqs = torch.cat([freqs[2:4], freqs[1:2]], dim=0)
    expected = _apply_rotary_pos_emb_bshd(
        packed.unsqueeze(1),
        shifted_freqs,
        rotary_interleaved=config.rotary_interleaved,
        multi_latent_attention=config.multi_latent_attention,
    ).squeeze(1)
    assert torch.allclose(output, expected)

    # The legacy mapping (no offsets) would start every sequence at position 0;
    # the result must differ from that to prove the offsets were honored.
    legacy_expected = _apply_rotary_pos_emb_bshd(
        packed.unsqueeze(1),
        torch.cat([freqs[:2], freqs[:1]], dim=0),
        rotary_interleaved=config.rotary_interleaved,
        multi_latent_attention=config.multi_latent_attention,
    ).squeeze(1)
    assert not torch.allclose(output, legacy_expected)


def test_apply_rotary_pos_emb_thd_falls_back_from_fusion_when_offsets_are_provided(monkeypatch):
    """With offsets, the fused THD kernel must be skipped in favor of the unfused path."""
    # Make the fused bshd kernel look importable so apply_rope_fusion=True setup
    # does not trip availability checks (the real kernel may be absent in CI).
    monkeypatch.setattr(rope_utils, "fused_apply_rotary_pos_emb", object())

    config = TransformerConfig(
        num_layers=1, hidden_size=4, num_attention_heads=1, apply_rope_fusion=True
    )
    cp_group = _FakeCpGroup(size=1, rank=0)

    packed = torch.tensor(
        [[[1.0, 2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0, 9.0]]]
    )
    freqs = (torch.arange(16, dtype=torch.float32).view(4, 1, 1, 4) + 1.0) / 10.0
    cu_seqlens = torch.tensor([0, 2, 3], dtype=torch.int32)
    offsets = torch.tensor([2, 1], dtype=torch.int32)

    # Any call into the fused THD kernel is a test failure.
    def _unexpected_fused(*args, **kwargs):
        raise AssertionError("fused THD RoPE should not be used when offsets are provided")

    monkeypatch.setattr(rope_utils, "fused_apply_rotary_pos_emb_thd", _unexpected_fused)

    with pytest.warns(UserWarning, match="offsets are not supported by fused THD RoPE"):
        output = apply_rotary_pos_emb(
            packed, freqs, config, cu_seqlens=cu_seqlens, cp_group=cp_group, offsets=offsets
        )

    # The fallback must match the unfused offset-aware computation.
    expected = _apply_rotary_pos_emb_bshd(
        packed.unsqueeze(1),
        torch.cat([freqs[2:4], freqs[1:2]], dim=0),
        rotary_interleaved=config.rotary_interleaved,
        multi_latent_attention=config.multi_latent_attention,
    ).squeeze(1)
    assert torch.allclose(output, expected)


def test_apply_rotary_pos_emb_thd_converts_offsets_to_python_ints(monkeypatch):
    """Offsets taken from a tensor must reach the freq helper as plain Python ints."""
    config = TransformerConfig(
        num_layers=1, hidden_size=4, num_attention_heads=1, apply_rope_fusion=False
    )
    cp_group = _FakeCpGroup(size=1, rank=0)

    packed = torch.tensor(
        [[[1.0, 2.0, 3.0, 4.0]], [[2.0, 3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0, 9.0]]]
    )
    freqs = (torch.arange(16, dtype=torch.float32).view(4, 1, 1, 4) + 1.0) / 10.0
    cu_seqlens = torch.tensor([0, 2, 3], dtype=torch.int32)
    offsets = torch.tensor([2, 1], dtype=torch.int32)

    recorded_offsets = []
    real_get_freqs = rope_utils._get_thd_freqs_on_this_cp_rank

    def _spy_get_freqs(cp_rank, cp_size, x, freqs, offset=0):
        # Capture each offset before delegating to the real implementation.
        recorded_offsets.append(offset)
        return real_get_freqs(cp_rank, cp_size, x, freqs, offset)

    monkeypatch.setattr(rope_utils, "_get_thd_freqs_on_this_cp_rank", _spy_get_freqs)

    apply_rotary_pos_emb(
        packed, freqs, config, cu_seqlens=cu_seqlens, cp_group=cp_group, offsets=offsets
    )

    assert recorded_offsets == [2, 1]
    assert all(isinstance(offset, int) for offset in recorded_offsets)
Loading