Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 173 additions & 35 deletions tests/unit_tests/models/test_mamba_model_expert_parallel_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Tests for full MambaModel inference with expert-parallel batch dimension sync.

When expert parallelism > 1 with strict matching (hybrid models), batch
dimensions are MAX-reduced across EP ranks. Different EP ranks can be in
dimensions are MAX-reduced across EP ranks. Different EP ranks can be in
one of four request states:

- NONE: 0 requests (dummy rank, uses is_expert_parallel_dummy_cuda_graph_step)
Expand All @@ -16,6 +16,8 @@
synchronization path (strict matching + MAX-reduce on batch dimensions).
"""

import itertools

import pytest
import torch
import torch.distributed as dist
Expand All @@ -29,6 +31,7 @@
from megatron.core.ssm.mamba_mixer import _check_mamba_sequence_packing_support
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord, delete_cuda_graphs
from megatron.core.transformer.enums import AttnBackend
from megatron.core.utils import is_fa_min_version
from tests.unit_tests.test_utilities import Utils
Expand All @@ -41,16 +44,28 @@

ALL_STATES = [NONE, DECODE, PREFILL, MIXED]

# Fixed expert-parallel size. When world_size > _EP_SIZE the remaining
# ranks form data-parallel replicas, each running the same EP combo
# independently.
_EP_SIZE = 4

# Combinatorial sweep: unordered combinations with repetition of ALL_STATES
# across the EP ranks. Since rank assignment is symmetric (shuffling ranks
# with the same multiset of states is not a distinct configuration), we use
# combinations_with_replacement rather than the full Cartesian product.
# For _EP_SIZE=4 this gives C(4+4-1, 4) = 35 test cases.
_STATE_COMBOS = list(itertools.combinations_with_replacement(ALL_STATES, _EP_SIZE))

# Batch dimensions used to set up each non-dummy state via
# add_dummy_requests_for_cudagraph_capture. These are intentionally small
# add_dummy_requests_for_cudagraph_capture. These are intentionally small
# to keep the tests fast while still exercising the EP padding logic.
_STATE_DIMS = {
# 2 decode requests, 1 token each -> 2 tokens total
DECODE: InferenceBatchDimensions(token_count=2, prefill_req_count=0, decode_req_count=2),
# 2 prefill requests with 16 tokens each -> 32 tokens total
PREFILL: InferenceBatchDimensions(token_count=32, prefill_req_count=2, decode_req_count=0),
# 2 decode (2 tokens) + 1 prefill (30 tokens) = 32 tokens
MIXED: InferenceBatchDimensions(token_count=32, prefill_req_count=1, decode_req_count=2),
# 4 decode (4 tokens) + 2 prefill (60 tokens) = 64 tokens
MIXED: InferenceBatchDimensions(token_count=64, prefill_req_count=2, decode_req_count=4),
}


Expand All @@ -69,20 +84,26 @@ def setup_method(self, method):
pytest.skip(reason, allow_module_level=True)
if not is_fa_min_version("2.7.3"):
pytest.skip("need flash-attn >= 2.7.3 for dynamic batching", allow_module_level=True)
if Utils.world_size < 2:
pytest.skip("EP test requires at least 2 GPUs", allow_module_level=True)
if Utils.world_size < _EP_SIZE:
pytest.skip(f"EP test requires at least {_EP_SIZE} GPUs", allow_module_level=True)
if Utils.world_size % _EP_SIZE != 0:
pytest.skip(
f"world_size ({Utils.world_size}) must be divisible by EP size ({_EP_SIZE})",
allow_module_level=True,
)

Utils.initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
expert_model_parallel_size=Utils.world_size,
expert_model_parallel_size=_EP_SIZE,
)

def teardown_method(self, method):
delete_cuda_graphs()
Utils.destroy_model_parallel()

def _build_model(self):
model_parallel_cuda_manual_seed(123)
model_parallel_cuda_manual_seed(123, inference_rng_tracker=True, force_reset_rng=True)
config = TransformerConfig(
num_layers=3,
mtp_hybrid_override_pattern="ME*",
Expand All @@ -94,6 +115,7 @@ def _build_model(self):
attention_backend=AttnBackend.fused,
num_moe_experts=2,
moe_token_dispatcher_type="alltoall",
cuda_graph_impl="local",
)
model = MambaModel(
config=config,
Expand All @@ -107,7 +129,12 @@ def _build_model(self):
return model

def _build_context(
self, model, *, num_cuda_graphs=16, use_cuda_graphs_for_non_decode_steps=True
self,
model,
*,
num_cuda_graphs=16,
use_cuda_graphs_for_non_decode_steps=True,
max_requests=None,
):
mamba_config = MambaInferenceStateConfig.from_model(model)
return DynamicInferenceContext(
Expand All @@ -120,6 +147,7 @@ def _build_context(
mamba_inference_state_config=mamba_config,
num_cuda_graphs=num_cuda_graphs,
use_cuda_graphs_for_non_decode_steps=use_cuda_graphs_for_non_decode_steps,
max_requests=max_requests,
),
)

Expand All @@ -141,6 +169,34 @@ def _assert_dynamic_inference_shape(self, model, ctx, rank, state_label):
f"got {tuple(out.shape)}"
)

@staticmethod
def _assert_cuda_graphs_were_replayed(expect_replayed, rank, label):
"""Assert that CUDA graphs were (or were not) recorded and replayed
during the preceding model.forward() call.

The inference path in CudaGraphManager records each layer's runner
into _CudagraphGlobalRecord.cudagraph_inference_record the first time
a graph is captured. A non-empty record with fwd_graph_recorded=True
on every runner confirms the graph was both recorded and replayed.
"""
record = _CudagraphGlobalRecord.cudagraph_inference_record
if expect_replayed:
assert len(record) > 0, (
f"Rank {rank} ({label}): expected CUDA graphs to be recorded and "
f"replayed, but cudagraph_inference_record is empty"
)
for runner, _graph_type, _args, _kwargs in record:
assert runner.fwd_graph_recorded, (
f"Rank {rank} ({label}): CUDA graph runner for "
f"{runner.base_module.__class__.__name__} (layer "
f"{runner.base_module.layer_number}) was not recorded"
)
else:
assert len(record) == 0, (
f"Rank {rank} ({label}): expected no CUDA graph replay, "
f"but cudagraph_inference_record has {len(record)} entries"
)

def _assert_dummy_forward_shape(self, model, rank):
"""Run model.forward with a single dummy token (no inference context),
mirroring the real engine's dummy_forward fallback, and verify the
Expand All @@ -156,31 +212,26 @@ def _assert_dummy_forward_shape(self, model, rank):
)

# ------------------------------------------------------------------
# test_ep_state_cross_product: full 4x4 matrix with mixed CUDA graphs
# test_ep_state_cross_product: combinatorial sweep with mixed CUDA graphs
# ------------------------------------------------------------------

@pytest.mark.parametrize(
"even_state,odd_state",
[(a, b) for a in ALL_STATES for b in ALL_STATES],
ids=[f"even={a}_odd={b}" for a in ALL_STATES for b in ALL_STATES],
)
@pytest.mark.parametrize("rank_states", _STATE_COMBOS, ids=[",".join(s) for s in _STATE_COMBOS])
@pytest.mark.internal
@torch.inference_mode()
def test_ep_state_cross_product(self, even_state, odd_state):
"""Test all 16 combinations of EP rank request states.
def test_ep_state_cross_product(self, rank_states):
"""Test all combinatorial (unordered, with repetition) assignments of
the four request states across EP ranks.

The context is built with use_cuda_graphs_for_non_decode_steps=True,
so the CUDA graph list contains decode-only, mixed, and prefill-only
graphs. After the EP all-reduce in match_graph_config, every rank
graphs. After the EP all-reduce in match_graph_config, every rank
(including dummy ranks) should always find a matching graph.

State setup uses add_dummy_requests_for_cudagraph_capture to populate
the context directly with the desired request configuration (including
mamba state allocation with zeroed conv/ssm states). No forward
passes or request lifecycle transitions are needed.
the context directly with the desired request configuration.
"""
rank = dist.get_rank()
my_state = even_state if rank % 2 == 0 else odd_state
ep_rank = parallel_state.get_expert_model_parallel_rank()
my_state = rank_states[ep_rank]
is_dummy = my_state == NONE

model = self._build_model()
Expand All @@ -201,9 +252,9 @@ def test_ep_state_cross_product(self, even_state, odd_state):
# ranks whose EP-adjusted dimensions inherit prefill/decode counts
# from peers — must find a matching graph.
assert ctx.using_cuda_graph_this_step(), (
f"Rank {rank} (state={my_state}): expected a CUDA graph match "
f"EP rank {ep_rank} (state={my_state}): expected a CUDA graph match "
f"with use_cuda_graphs_for_non_decode_steps=True "
f"(even={even_state}, odd={odd_state})"
f"(rank_states={rank_states})"
)

# All EP ranks must agree on padded token count.
Expand All @@ -217,10 +268,13 @@ def test_ep_state_cross_product(self, even_state, odd_state):
assert tc_max.item() == tc_min.item(), (
f"Padded token count mismatch across EP ranks: "
f"min={tc_min.item()}, max={tc_max.item()} "
f"(even={even_state}, odd={odd_state})"
f"(rank_states={rank_states})"
)

self._assert_dynamic_inference_shape(model, ctx, rank, my_state)
self._assert_dynamic_inference_shape(model, ctx, ep_rank, my_state)
self._assert_cuda_graphs_were_replayed(
True, ep_rank, f"state={my_state}, rank_states={rank_states}"
)

# ------------------------------------------------------------------
# test_dummy_bailout_with_decode_only_cuda_graphs: dedicated bail-out
Expand All @@ -236,10 +290,10 @@ def test_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state):
are available.

With use_cuda_graphs_for_non_decode_steps=False, the CUDA graph list
contains only decode-only graphs. When any EP rank has prefill
contains only decode-only graphs. When any EP rank has prefill
requests, adjust_batch_dims_for_expert_parallelism returns None
(forcing eager mode), and match_graph_config returns None for all
ranks. A dummy rank then bails out of initialize_attention_state
ranks. A dummy rank then bails out of initialize_attention_state
early (padded_batch_dimensions is not set).

This test verifies that:
Expand All @@ -249,10 +303,10 @@ def test_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state):
output (padded_batch_dimensions is computed via the non-graph
fallback path).

Even ranks are dummy; odd ranks have the parametrized peer_state.
Even EP ranks are dummy; odd EP ranks have the parametrized peer_state.
"""
rank = dist.get_rank()
is_even = rank % 2 == 0
ep_rank = parallel_state.get_expert_model_parallel_rank()
is_even = ep_rank % 2 == 0

model = self._build_model()
ctx = self._build_context(model, use_cuda_graphs_for_non_decode_steps=False)
Expand All @@ -270,14 +324,98 @@ def test_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state):
# Verify: no rank should have matched a CUDA graph because the
# peer has prefill but only decode graphs are available.
assert not ctx.using_cuda_graph_this_step(), (
f"Rank {rank}: expected no CUDA graph match with "
f"EP rank {ep_rank}: expected no CUDA graph match with "
f"decode-only graphs and peer_state={peer_state}"
)

if is_even:
# Dummy rank bailed out — exercise the eager fallback.
self._assert_dummy_forward_shape(model, rank)
self._assert_dummy_forward_shape(model, ep_rank)
else:
# Non-dummy rank: padded_batch_dimensions is set via the
# non-graph fallback path in initialize_attention_state.
self._assert_dynamic_inference_shape(model, ctx, rank, peer_state)
self._assert_dynamic_inference_shape(model, ctx, ep_rank, peer_state)
self._assert_cuda_graphs_were_replayed(
False, ep_rank, f"decode-only graphs, peer_state={peer_state}"
)

# ------------------------------------------------------------------
# test_mixed_cuda_graphs_tokens_exceed_max_requests: eager fallback
# ------------------------------------------------------------------

@pytest.mark.parametrize(
"peer_state", [PREFILL, MIXED], ids=[f"peer={s}" for s in [PREFILL, MIXED]]
)
@pytest.mark.internal
@torch.inference_mode()
def test_mixed_cuda_graphs_tokens_exceed_max_requests(self, peer_state):
"""Verify eager fallback when mixed CUDA graphs are allowed but
a rank's token count exceeds the CUDA graph capacity.

With use_cuda_graphs_for_non_decode_steps=True, the CUDA graph
list includes mixed and prefill-only graphs. However, the
maximum CUDA graph token capacity is bounded by max_requests
(specifically, max_requests * (num_speculative_tokens + 1)).

When one EP rank has a token count exceeding this capacity, no
CUDA graph can accommodate the EP-adjusted dimensions.
match_graph_config returns None for all ranks, forcing eager
mode globally. This test verifies that:
- No rank matches a CUDA graph (eager mode is forced).
- Dummy ranks bail out and produce correct shapes via the
eager dummy_forward path.
- Non-dummy ranks produce correct shapes via the eager
padded_batch_dimensions fallback.
"""
ep_rank = parallel_state.get_expert_model_parallel_rank()
is_even = ep_rank % 2 == 0

model = self._build_model()

# Use a small max_requests so that the CUDA graph capacity
# (max_requests tokens with no speculative decoding) is easily
# exceeded by a prefill-heavy rank.
small_max_requests = 16
ctx = self._build_context(
model, use_cuda_graphs_for_non_decode_steps=True, max_requests=small_max_requests
)

# Even EP ranks are dummy (no requests). Odd EP ranks get a state
# whose token count exceeds small_max_requests.
overflow_token_count = small_max_requests + 16 # 32 tokens > 16 capacity
overflow_dims = {
PREFILL: InferenceBatchDimensions(
token_count=overflow_token_count, prefill_req_count=2, decode_req_count=0
),
MIXED: InferenceBatchDimensions(
token_count=overflow_token_count, prefill_req_count=1, decode_req_count=2
),
}

if not is_even:
ctx.add_dummy_requests_for_cudagraph_capture(overflow_dims[peer_state])

# Initialize attention state (EP collective).
if is_even:
ctx.initialize_attention_state(is_expert_parallel_dummy_cuda_graph_step=True)
else:
ctx.initialize_attention_state()

# No rank should have matched a CUDA graph — the EP-adjusted
# token count exceeds every graph's capacity.
assert not ctx.using_cuda_graph_this_step(), (
f"EP rank {ep_rank}: expected no CUDA graph match when token count "
f"({overflow_token_count}) exceeds max_requests ({small_max_requests}), "
f"peer_state={peer_state}"
)

if is_even:
# Dummy rank bailed out — exercise the eager fallback.
self._assert_dummy_forward_shape(model, ep_rank)
else:
# Non-dummy rank: padded_batch_dimensions is set via the
# eager fallback path. Verify shape correctness.
self._assert_dynamic_inference_shape(model, ctx, ep_rank, peer_state)
self._assert_cuda_graphs_were_replayed(
False, ep_rank, f"overflow tokens, peer_state={peer_state}"
)
Loading