Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 91 additions & 19 deletions megatron/core/transformer/moe/fused_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE

from typing import Optional

from megatron.core.utils import internal_api

try:
Expand Down Expand Up @@ -280,9 +282,11 @@ def init_hybrid_ep_buffer(
hidden_dim: int,
seq_len: int,
num_local_experts: int,
num_sms_dispatch_api: int,
num_sms_combine_api: int,
fp8_dispatch: bool,
num_sms_dispatch_api: Optional[int] = None,
num_sms_combine_api: Optional[int] = None,
num_blocks_permute: Optional[int] = None,
num_blocks_unpermute: Optional[int] = None,
fp8_dispatch: bool = False,
) -> None:
'''
Initialize the HybridEP buffer, including buffer allocation and metadata
Expand All @@ -301,23 +305,35 @@ def init_hybrid_ep_buffer(
Maximum sequence length of the input tensor.
num_local_experts (int):
Number of local experts.
num_sms_dispatch_api (int):
num_sms_dispatch_api (Optional[int]):
Number of SMs used by the dispatch API.
num_sms_combine_api (int):
num_sms_combine_api (Optional[int]):
Number of SMs used by the combine API.
num_blocks_permute (Optional[int]):
Number of blocks used by the permute part.
num_blocks_unpermute (Optional[int]):
Number of blocks used by the unpermute part.
fp8_dispatch (bool):
Whether to use FP8 communication during the dispatch phase.
'''
assert not fp8_dispatch, "HybridEP dispatcher does not support fp8 dispatch now"
global _hybrid_ep_buffer
kwargs = {}
if num_sms_dispatch_api is not None:
kwargs['num_sms_dispatch_api'] = num_sms_dispatch_api
if num_sms_combine_api is not None:
kwargs['num_sms_combine_api'] = num_sms_combine_api
if num_blocks_permute is not None:
kwargs['num_blocks_permute'] = num_blocks_permute
if num_blocks_unpermute is not None:
kwargs['num_blocks_unpermute'] = num_blocks_unpermute
_hybrid_ep_buffer = HybridEPBuffer(
group=group,
hidden_dim=hidden_dim,
max_num_of_tokens_per_rank=seq_len,
num_local_experts=num_local_experts,
use_fp8=fp8_dispatch,
num_sms_dispatch_api=num_sms_dispatch_api,
num_sms_combine_api=num_sms_combine_api,
**kwargs,
)


Expand All @@ -342,14 +358,34 @@ def forward(
probs,
group,
num_local_experts,
num_sms_dispatch_api=24,
num_sms_combine_api=24,
num_sms_dispatch_api=None,
num_sms_combine_api=None,
num_blocks_permute=None,
num_blocks_unpermute=None,
fused=False,
num_permuted_tokens=None,
pad_multiple=None,
):
'''
Forward pass of fused dispatch of the HybridEP backend
'''
if fused or num_blocks_permute is not None or num_blocks_unpermute is not None:
import inspect
import warnings

sig = inspect.signature(HybridEPBuffer.dispatch_with_permute)
if 'fuse_permute_dispatch' not in sig.parameters:
warnings.warn(
"Current DeepEP version does not support fused permute dispatch or "
"num_blocks_permute/num_blocks_unpermute. Falling back to unfused "
"HybridEP dispatch.",
UserWarning,
stacklevel=2,
)
fused = False
num_blocks_permute = None
num_blocks_unpermute = None

if _hybrid_ep_buffer is None:
seq_len, hidden_dim = x.shape[-2:]
fp8_dispatch = False # Currently, we do not support fp8 dispatch
Expand All @@ -360,6 +396,8 @@ def forward(
num_local_experts,
num_sms_dispatch_api,
num_sms_combine_api,
num_blocks_permute,
num_blocks_unpermute,
fp8_dispatch,
)
# If we provide the num_permuted_tokens, we do not need to use sync to
Expand All @@ -381,10 +419,12 @@ def forward(
pad_multiple=pad_multiple,
num_permuted_tokens=num_permuted_tokens,
non_blocking=non_blocking,
**({"fuse_permute_dispatch": fused} if fused else {}),
)

ctx.handle = handle
ctx.pad_multiple = pad_multiple
ctx.fused = fused
return (
dispatched_hidden,
dispatched_probs,
Expand All @@ -400,9 +440,26 @@ def backward(ctx, grad_x, grad_probs, grad_scaling_factor, grad_tokens_per_exper
'''
handle = ctx.handle
combined_hidden, combined_probs = _hybrid_ep_buffer.combine_with_unpermute(
hidden=grad_x, probs=grad_probs, handle=handle, pad_multiple=ctx.pad_multiple
hidden=grad_x,
probs=grad_probs,
handle=handle,
pad_multiple=ctx.pad_multiple,
**({"fuse_unpermute_combine": ctx.fused} if ctx.fused else {}),
)
return (
combined_hidden,
None,
combined_probs,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
return combined_hidden, None, combined_probs, None, None, None, None, None, None, None


@internal_api
Expand All @@ -412,16 +469,20 @@ class HybridEPCombine(torch.autograd.Function):
'''

@staticmethod
def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None):
def forward(ctx, x, handle, num_permuted_tokens=None, pad_multiple=None, fused=False):
'''
Forward pass of fused combine of the HybridEP backend
'''
combined_hidden, _ = _hybrid_ep_buffer.combine_with_unpermute(
hidden=x, handle=handle, pad_multiple=pad_multiple
hidden=x,
handle=handle,
pad_multiple=pad_multiple,
**({"fuse_unpermute_combine": fused} if fused else {}),
)
ctx.handle = handle
ctx.pad_multiple = pad_multiple
ctx.num_permuted_tokens = num_permuted_tokens
ctx.fused = fused
return combined_hidden

@staticmethod
Expand All @@ -436,6 +497,7 @@ def backward(ctx, grad_x):
handle=handle,
pad_multiple=ctx.pad_multiple,
num_permuted_tokens=ctx.num_permuted_tokens,
**({"fuse_permute_dispatch": ctx.fused} if ctx.fused else {}),
)
return dispatched_hidden, None, None, None, None

Expand All @@ -449,8 +511,11 @@ def hybrid_ep_dispatch(
probs,
group,
num_local_experts,
num_sms_dispatch_api=24,
num_sms_combine_api=24,
num_sms_dispatch_api=None,
num_sms_combine_api=None,
num_blocks_permute=None,
num_blocks_unpermute=None,
fused=False,
num_permuted_tokens=None,
pad_multiple=None,
):
Expand All @@ -469,10 +534,14 @@ def hybrid_ep_dispatch(
Process group used for communication.
num_local_experts (int):
Number of local experts.
num_sms_dispatch_api (int):
num_sms_dispatch_api (Optional[int]):
Number of SMs used by the dispatch API.
num_sms_combine_api (int):
num_sms_combine_api (Optional[int]):
Number of SMs used by the combine API.
num_blocks_permute (Optional[int]):
Number of blocks used by the permute part.
num_blocks_unpermute (Optional[int]):
Number of blocks used by the unpermute part.
num_permuted_tokens (int):
Number of tokens after permute. HybridEP uses this to allocate buffers.
If not provided, HybridEP obtains the size from a GPU tensor,
Expand All @@ -489,12 +558,15 @@ def hybrid_ep_dispatch(
num_local_experts,
num_sms_dispatch_api,
num_sms_combine_api,
num_blocks_permute,
num_blocks_unpermute,
fused,
num_permuted_tokens,
pad_multiple,
)

@internal_api
def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple):
def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple, fused=False):
'''
Perform fused combine operation for unpermute + combine a2a + unpermute
using the HybridEP backend
Expand All @@ -511,7 +583,7 @@ def hybrid_ep_combine(x, handle, num_permuted_tokens, pad_multiple):
The alignment multiple required for FP8 GEMM. If not provided, no padding
is performed.
'''
return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple)
return HybridEPCombine.apply(x, handle, num_permuted_tokens, pad_multiple, fused)

else:
hybrid_ep_dispatch = None
Expand Down
6 changes: 5 additions & 1 deletion megatron/core/transformer/moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,8 +1048,11 @@ def dispatch(
num_local_experts=self.num_local_experts,
num_sms_dispatch_api=self.config.moe_hybridep_num_sms,
num_sms_combine_api=self.config.moe_hybridep_num_sms,
num_blocks_permute=self.config.moe_hybridep_num_blocks_permute,
num_blocks_unpermute=self.config.moe_hybridep_num_blocks_unpermute,
num_permuted_tokens=self.num_permuted_tokens,
pad_multiple=self.pad_multiple,
fused=self.config.moe_permute_fusion_into_hybridep,
)
)

Expand All @@ -1071,6 +1074,7 @@ def combine(
handle=self.handle,
num_permuted_tokens=self.num_permuted_tokens,
pad_multiple=self.pad_multiple,
fused=self.config.moe_permute_fusion_into_hybridep,
)
# Release the used handle/num_permuted_tokens which could change in each iteration.
# For drop_and_pad mode, we don't need to reset the num_permuted_tokens and
Expand Down Expand Up @@ -1357,8 +1361,8 @@ def __init__(

self.num_local_experts = num_local_experts
self.local_expert_indices = local_expert_indices
assert self.tp_size * self.ep_size > 1, "Flex token dispatcher requires TPxEP > 1"
if self.config.moe_flex_dispatcher_backend == "deepep":
assert self.tp_size * self.ep_size > 1, "DeepEP dispatcher requires TPxEP > 1"
self._comm_manager = _DeepepManager(
group=self.tp_ep_group,
num_local_experts=self.num_local_experts,
Expand Down
19 changes: 16 additions & 3 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,9 @@ class TransformerConfig(ModelParallelConfig):
Options are "deepep" and "hybridep". Currently only "hybridep" backend supports
the MNNVL case."""

moe_permute_fusion_into_hybridep: bool = False
"""Fuse token rearrangement ops during token dispatching for HybridEP."""

moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""

Expand Down Expand Up @@ -803,9 +806,19 @@ class TransformerConfig(ModelParallelConfig):
moe_deepep_num_sms: int = 20
"""Number of SMs to use for DeepEP."""

moe_hybridep_num_sms: int = 16
"""Number of SMs to use for HybridEP. In pure NVL scenarios,
16 SMs can generally achieve good bandwidth."""
moe_hybridep_num_sms: Optional[int] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we add the defaults used to the docstring here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer not to do this, as the default number of SMs used on the HybridEP side may change.

"""Number of SMs to use for HybridEP. None uses the default from DeepEP.
In pure NVL scenarios, 16 SMs can generally achieve good bandwidth."""

moe_hybridep_num_blocks_permute: Optional[int] = None
"""Number of CUDA thread blocks for the permute part in HybridEP.
When permute_fusion_into_hybridep is True, this sets the number
of SMs for the permute part (only 1 block per SM)."""

moe_hybridep_num_blocks_unpermute: Optional[int] = None
"""Number of CUDA thread blocks for the unpermute part in HybridEP.
When permute_fusion_into_hybridep is True, this sets the number
of SMs for the unpermute part (only 1 block per SM)."""

##################
# Context Parallel
Expand Down
5 changes: 4 additions & 1 deletion tests/unit_tests/models/test_mamba_moe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@
"moe_ffn_hidden_size": 1856,
"moe_flex_dispatcher_backend": "deepep",
"moe_grouped_gemm": True,
"moe_hybridep_num_sms": 16,
"moe_hybridep_num_sms": None,
"moe_hybridep_num_blocks_permute": None,
"moe_hybridep_num_blocks_unpermute": None,
"moe_input_jitter_eps": None,
"moe_latent_size": None,
"moe_layer_freq": 1,
Expand All @@ -170,6 +172,7 @@
"moe_pad_experts_for_cuda_graph_inference": False,
"moe_per_layer_logging": False,
"moe_permute_fusion": False,
"moe_permute_fusion_into_hybridep": False,
"moe_router_bias_update_rate": 0.001,
"moe_router_dtype": "fp64",
"moe_router_enable_expert_bias": True,
Expand Down
Loading
Loading