Summary
`transformer_engine.pytorch.GroupedLinear` is ~2× slower than a Python loop of `F.linear` calls.

Setup: 8× B300, BF16, EP=2 (DP=4, 4 experts/rank), seq=8192, 200 timed steps + 10 warmup. Same model, same stacked weights; only the per-expert FFN dispatch differs.
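For context, here is a minimal NumPy sketch of the two dispatch patterns being compared (shapes and names are illustrative only, not TE's actual API; `GroupedLinear` is expected to fuse the per-expert GEMMs into a single grouped call rather than launching them one by one):

```python
import numpy as np

# 4 experts/rank; stacked weights: one (d_model, d_ff) matrix per expert
rng = np.random.default_rng(0)
d_model, d_ff, n_experts = 8, 16, 4
w = rng.standard_normal((n_experts, d_model, d_ff))

# Tokens already routed: m_splits[i] tokens belong to expert i
m_splits = [3, 5, 2, 6]
x = rng.standard_normal((sum(m_splits), d_model))

# Loop dispatch: one GEMM per expert, issued sequentially
# (this is what the F.linear-in-a-Python-loop baseline does)
outs, start = [], 0
for i, m in enumerate(m_splits):
    outs.append(x[start:start + m] @ w[i])
    start += m
y_loop = np.concatenate(outs)

# Grouped dispatch: conceptually a single call taking (x, m_splits, w)
# and producing all expert outputs at once
splits = np.split(x, np.cumsum(m_splits)[:-1])
y_grouped = np.concatenate([xi @ wi for xi, wi in zip(splits, w)])

# Both paths compute the same result; only the dispatch differs
assert np.allclose(y_loop, y_grouped)
```

Mathematically the two are identical, which is why the ~2× gap below points at kernel/dispatch overhead rather than a numerical difference.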
| Batch | Loop (`F.linear` × 4) | GroupedLinear | Ratio |
|-------|-----------------------|---------------|-------|
| 1     | 240 ms                | 449 ms        | 1.87× |
| 2     | 242 ms                | 474 ms        | 1.96× |
| 4     | 242 ms                | 511 ms        | 2.11× |
| 8     | 306 ms                | 617 ms        | 2.02× |
| 16    | 506 ms                | OOM           | —     |
Nsys profiling

Huggingface:
- 1 layer (+5.726 ms)
  - Attention (+1.487 ms)
  - MLP (+3.451 ms)
    - Loop (+952.678 μs)

Transformer Engine:
- 1 layer (+6.171 ms)
  - Attention (+2.351 ms)
  - MLP (+3.486 ms)
    - TE: grouped (+966.614 μs)
Repro (in PR #2642)
```bash
cd docs/examples/te_mixtral

# Loop (Tier 2)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
  --improvement 2 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
  --warmup-steps 10 --train-steps 200

# GroupedLinear (Tier 3)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
  --improvement 3 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
  --warmup-steps 10 --train-steps 200
```