Conversation
Pull request overview
This PR adds elastic Expert Parallelism (EP) behavior to the dispatch/combine operations, enabling jobs to continue making progress when some ranks become unresponsive. The implementation includes timeout-based detection of unresponsive ranks, active rank tracking, and graceful degradation when ranks fail. Additionally, the PR corrects the wall-clock frequency API naming from MHz to KHz for accuracy.
Changes:
- Added elastic EP state management, with an `active_ranks` tensor and a `timeout_us` parameter threaded through the dispatch/combine API stack
- Implemented device-side timeout and rank-activity checking utilities, using atomic operations for thread safety
- Renamed wall-clock frequency API from MHz to KHz (correcting the naming to match actual HIP API behavior)
- Added comprehensive elastic EP test coverage and example integration
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| python/mori/ops/dispatch_combine.py | Added active_ranks and timeout_us parameters to dispatch*/combine* methods |
| python/mori/kernel_profiler/__init__.py | Updated to use renamed KHz wall-clock API |
| src/pybind/mori.cpp | Added MaybeUpdateElasticState helper and parameter validation; updated all dispatch/combine/recv functions |
| include/mori/ops/dispatch_combine/dispatch_combine.hpp | Added wallClockRateKHz, activeRanks, timeoutTicks fields and SetElasticState method |
| include/mori/utils/hip_helper.hpp | Renamed wall-clock API functions from MHz to KHz |
| include/mori/core/transport/p2p/device_primitives.hpp | Added elastic EP device utilities: IsRankActive, MarkRankInactive, WaitUntil*OrTimeout |
| src/ops/dispatch_combine/dispatch_combine.cpp | Initialize wallClockRateKHz in constructor |
| src/ops/dispatch_combine/intranode.hpp | Added elastic checks and timeout handling in dispatch/combine/barrier operations |
| src/ops/dispatch_combine/internode_v1.cpp | Added elastic checks throughout internode dispatch/combine/sync operations |
| src/ops/dispatch_combine/low_latency_async.cpp | Added elastic checks in async low-latency dispatch/combine paths |
| tests/python/ops/test_dispatch_combine.py | Added test_dispatch_combine_elastic_ep with dropout simulation |
| examples/ops/dispatch_combine/test_dispatch_combine_internode.py | Extended with drop_rank and timeout_us support for elastic testing |
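For reference, the relationship between the renamed KHz wall-clock API and the `timeoutTicks` field can be sketched as follows. This is an illustrative Python model, not the actual mori C++ code; the function name and the use of -1 as a "timeout disabled" sentinel are assumptions:

```python
def timeout_us_to_ticks(timeout_us, wall_clock_rate_khz):
    """Convert a timeout in microseconds to device wall-clock ticks.

    The device wall clock runs at wall_clock_rate_khz kilohertz, i.e.
    wall_clock_rate_khz ticks per millisecond, which is
    wall_clock_rate_khz / 1000 ticks per microsecond.
    """
    if timeout_us is None or timeout_us < 0:
        return -1  # assumed sentinel: timeout disabled
    # ticks = us * (ticks per us) = us * khz / 1000
    return (timeout_us * wall_clock_rate_khz) // 1000

# Example: a 25 MHz wall clock is 25000 KHz -> 25 ticks per microsecond,
# which is why reporting the rate in KHz (not MHz) avoids losing precision
# for sub-MHz-granularity clocks.
```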
Comments suppressed due to low confidence (2)
python/mori/ops/dispatch_combine.py:219
- The docstrings for the `dispatch` and `combine` methods don't document the new `active_ranks` and `timeout_us` parameters. Please add documentation for these parameters explaining:
  - `active_ranks`: Optional int32 CUDA tensor of shape (world_size,) indicating which ranks are active (1) or inactive (0). Used for elastic EP behavior.
  - `timeout_us`: Optional timeout in microseconds for detecting unresponsive ranks. When a rank doesn't respond within this time, it is marked inactive in the `active_ranks` tensor. Use -1 or None to disable the timeout.
"""Dispatch tokens to experts based on top-k indices.
Args:
input: Input token tensor.
weights: Token weights for each expert.
scales: Quantization scales (optional).
indices: Top-k expert indices.
block_num: Override config.block_num if > 0.
warp_per_block: Override config.warp_num_per_block if > 0.
"""
python/mori/ops/dispatch_combine.py:297
- The docstring for the `combine` method doesn't document the new `active_ranks` and `timeout_us` parameters. Please add documentation for these parameters explaining:
  - `active_ranks`: Optional int32 CUDA tensor of shape (world_size,) indicating which ranks are active (1) or inactive (0). Used for elastic EP behavior.
  - `timeout_us`: Optional timeout in microseconds for detecting unresponsive ranks. When a rank doesn't respond within this time, it is marked inactive in the `active_ranks` tensor. Use -1 or None to disable the timeout.
"""Combine tokens from experts back to original positions.
Args:
input: Expert output tensor.
weights: Token weights for weighted combination.
indices: Top-k expert indices.
block_num: Override config.block_num if > 0.
warp_per_block: Override config.warp_num_per_block if > 0.
use_external_inp_buf: Override config.use_external_inp_buf if >= 0.
0 = use zero-copy (registered combine input buffer),
1 = use external input buffer (non-zero-copy).
call_reset: Whether to call reset after combine.
"""
```diff
 if op.config.kernel_type is mori.ops.EpDispatchCombineKernelType.AsyncLL:
     ret = op.dispatch_send(token, weights, scales, indices)
     op.dispatch_recv()
 else:
-    ret = op.dispatch(token, weights, scales, indices)
+    ret = op.dispatch(token, weights, scales, indices, **kwargs)
 return ret
```
For the AsyncLL kernel type, the elastic parameters are not passed to `dispatch_send`/`dispatch_recv`. The AsyncLL kernels do support elastic EP (they have `IsRankActive` checks), and the `dispatch_send`/`combine_send`/`dispatch_recv`/`combine_recv` methods already accept `active_ranks` and `timeout_us` parameters. Consider passing these parameters through the kwargs for AsyncLL as well to enable elastic EP support for this kernel type in the example.
```diff
 if op.config.kernel_type is mori.ops.EpDispatchCombineKernelType.AsyncLL:
     ret = op.combine_send(token, weights, indices)
     op.combine_recv()
 else:
-    ret = op.combine(token, weights, indices)
+    ret = op.combine(token, weights, indices, **kwargs)
```
Similar to `run_dispatch`: for the AsyncLL kernel type, the elastic parameters are not passed to `combine_send`/`combine_recv`. Consider passing these parameters through the kwargs for AsyncLL as well to enable elastic EP support for this kernel type in the example.
Motivation
This PR adds elastic EP behavior to `dispatch_combine` so jobs can continue making progress when some ranks or nodes become unresponsive, instead of hanging on cross-rank synchronization.

The goal is to improve robustness for internode MoE dispatch/combine under partial-failure scenarios while keeping the existing behavior unchanged when elastic mode is not enabled.
Technical Details
- Added elastic-state (`active_ranks`) and timeout (`timeout_us`) plumbing from the Python API -> pybind -> C++ handle -> device kernels.
- Added device-side utilities in `device_primitives` (rank-active checks, rank deactivation, wait-with-timeout helpers) and applied them across internode/intranode/low-latency dispatch+combine synchronization points.
- Extended `EpDispatchCombineHandle` to store elastic state and the wall-clock rate, and converted the timeout from microseconds to device wall-clock ticks.
- Renamed `get_cur_device_wall_clock_freq_mhz` -> `get_cur_device_wall_clock_freq_khz`.
- Extended the `EpDispatchCombineOp` interfaces (`dispatch*`/`combine*`) with optional `active_ranks` and `timeout_us` parameters.
- Added test coverage in `tests/python/ops/test_dispatch_combine.py` (`test_dispatch_combine_elastic_ep`).
- Extended the internode example with rank dropout simulation (`drop_rank`) and timeout-based inactive-rank handling.

Test Plan
- `tests/python/ops/test_dispatch_combine.py::test_dispatch_combine`
- `tests/python/ops/test_dispatch_combine.py::test_dispatch_combine_elastic_ep`
- `examples/ops/dispatch_combine/test_dispatch_combine_internode.py` with:
  - baseline run (`drop_rank=-1`)
  - elastic run (`drop_rank=<rank_id>`, `timeout_us=<value>`)

Test Result