Commit 13f86a2

Add checks for MNNVL support for tests that use the API
1 parent fea55b7 commit 13f86a2

1 file changed: +12, -1 lines

tests/comm/test_trtllm_moe_alltoall.py

Lines changed: 12 additions & 1 deletion
@@ -16,11 +16,14 @@
 
 import pytest
 import torch
-
+import pynvml
 from flashinfer.comm.mapping import Mapping
+from flashinfer.comm.mnnvl import MnnvlMemory
 
 import flashinfer.comm.trtllm_moe_alltoall as trtllm_moe_alltoall
 
+pynvml.nvmlInit()
+
 
 @pytest.fixture(autouse=True, scope="session")
 def setup_test_environment():
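
The new module-level pynvml.nvmlInit() call runs once when the test module is imported. As a minimal sketch of an alternative (not part of this commit), the NVML handle could instead be opened and released by a session-scoped fixture; the fixture name nvml_session is hypothetical, and only the standard pynvml.nvmlInit/nvmlShutdown calls are assumed:

import pynvml
import pytest


@pytest.fixture(autouse=True, scope="session")
def nvml_session():
    # Initialize NVML once for the whole test session instead of at import time.
    pynvml.nvmlInit()
    yield
    # Release the NVML handle when the session ends.
    pynvml.nvmlShutdown()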
@@ -90,6 +93,10 @@ def make_payload(num_tokens, vector_dim, dtype):
     "num_tokens,vector_dim,num_experts,top_k",
     SINGLE_GPU_PARAMS,
 )
+@pytest.mark.skipif(
+    not MnnvlMemory.supports_mnnvl(),
+    reason="Mnnvl memory is not supported on this platform",
+)
 def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k):
     """Test MOE alltoall communication on single GPU."""
     torch.cuda.set_device(0)
@@ -551,6 +558,10 @@ def test_moe_combine_multi_rank_single_gpu(
     )
 
 
+@pytest.mark.skipif(
+    not MnnvlMemory.supports_mnnvl(),
+    reason="Mnnvl memory is not supported on this platform",
+)
 def test_moe_workspace_size_per_rank():
     """Test the workspace size per rank for the MoeAlltoAll operation."""
     ep_size = 8
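
The commit applies the same skipif guard verbatim to each test that exercises the MNNVL API. As a hedged sketch (not from this commit), the guard could be kept in one reusable marker so additional tests only need a single decorator; the marker name requires_mnnvl and the example test are hypothetical, while MnnvlMemory.supports_mnnvl() comes from the diff above:

import pytest
from flashinfer.comm.mnnvl import MnnvlMemory

# Reusable marker: skip any decorated test when MNNVL memory is unavailable.
requires_mnnvl = pytest.mark.skipif(
    not MnnvlMemory.supports_mnnvl(),
    reason="Mnnvl memory is not supported on this platform",
)


@requires_mnnvl
def test_uses_mnnvl_api():
    # Placeholder body; a real test would exercise trtllm_moe_alltoall here.
    ...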
