Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def create_strategy(
payload_in_workspace: bool = False,
alltoall_result_do_sum: bool = True,
use_flashinfer: bool = False,
hidden_size: Optional[int] = None,
) -> Optional[Communication]:
"""
Create the best communication method for the given configuration
Expand All @@ -78,6 +79,9 @@ def create_strategy(
expert_size_per_partition: Number of experts per partition (required for DeepEP)
payload_in_workspace: If True, final_hidden_states is already in workspace (for NVLinkOneSided)
alltoall_result_do_sum: If True, sum the alltoall results (for NVLinkTwoSided)
hidden_size: Actual MoE activation dimension (the A2A payload width).
For latent-MoE models this is moe_latent_size, not pretrained_config.hidden_size.
Falls back to pretrained_config.hidden_size when not provided.
# TODO: Need a way to indicate whether EPLB is enabled.

Returns:
Expand All @@ -89,7 +93,8 @@ def create_strategy(
"""
# Extract parameters from model_config
mapping = model_config.mapping
hidden_size = model_config.pretrained_config.hidden_size
if hidden_size is None:
hidden_size = model_config.pretrained_config.hidden_size
act_dtype = model_config.torch_dtype
quant_config = model_config.quant_config
max_num_tokens = model_config.max_num_tokens
Expand Down Expand Up @@ -120,6 +125,7 @@ def create_strategy(
payload_in_workspace,
alltoall_result_do_sum,
use_flashinfer,
hidden_size=hidden_size,
)

# Auto-selection: Try strategies in priority order using try-catch
Expand Down Expand Up @@ -218,6 +224,7 @@ def _create_forced_method(
payload_in_workspace: bool,
alltoall_result_do_sum: bool,
use_flashinfer: bool,
hidden_size: Optional[int] = None,
) -> Communication:
"""
Create a specific method (for debugging/testing)
Expand All @@ -228,7 +235,8 @@ def _create_forced_method(
"""
# Extract parameters from model_config
mapping = model_config.mapping
hidden_size = model_config.pretrained_config.hidden_size
if hidden_size is None:
hidden_size = model_config.pretrained_config.hidden_size
act_dtype = model_config.torch_dtype
quant_config = model_config.quant_config
max_num_tokens = model_config.max_num_tokens
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def _create_comm_strategy_auto(self) -> Communication:
# Keep updated with more supported backends.
alltoall_result_do_sum=True,
use_flashinfer=self.use_flashinfer,
hidden_size=self.hidden_size,
)

def forward_impl(
Expand Down
Loading