GroupedLinear performance on Mixtral-8x7B at EP=2 #2939

@faradawn

Description

Summary

transformer_engine.pytorch.GroupedLinear is ~2× slower end-to-end than a plain Python loop of per-expert F.linear calls.

Setup: 8× B300 GPUs, BF16, EP=2 (DP=4, 4 experts per rank), sequence length 8192, 200 timed steps after 10 warmup steps. Same model and same stacked expert weights in both runs; only the per-expert FFN dispatch differs.

| Batch | Loop (F.linear × 4) | GroupedLinear | Ratio |
|------:|--------------------:|--------------:|------:|
| 1     | 240 ms              | 449 ms        | 1.87× |
| 2     | 242 ms              | 474 ms        | 1.96× |
| 4     | 242 ms              | 511 ms        | 2.11× |
| 8     | 306 ms              | 617 ms        | 2.02× |
| 16    | 506 ms              | OOM           | n/a   |
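
For reference, a minimal sketch of the two dispatch paths being compared (routing and variable names here are illustrative, not the PR's actual code; dimensions follow Mixtral-8x7B, and the GroupedLinear call assumes TE's `forward(inp, m_splits)` signature with per-expert token counts in `m_splits`):

```python
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

num_experts, hidden, ffn = 4, 4096, 14336  # 4 experts/rank at EP=2

# Tier 2: Python loop of per-expert F.linear calls over stacked weights.
def loop_fc1(tokens_per_expert, w1_stack):
    # tokens_per_expert: list of [m_e, hidden] tensors after routing
    # w1_stack: [num_experts, ffn, hidden], one weight slice per expert
    return [F.linear(t, w1_stack[e]) for e, t in enumerate(tokens_per_expert)]

# Tier 3: one GroupedLinear covering all experts on this rank.
grouped_fc1 = te.GroupedLinear(num_experts, hidden, ffn, bias=False)

def grouped_fc1_fwd(sorted_tokens, m_splits):
    # sorted_tokens: [sum(m_splits), hidden], tokens sorted by expert id
    # m_splits: per-expert token counts summing to sorted_tokens.shape[0]
    return grouped_fc1(sorted_tokens, m_splits)
```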

Nsys profiling

(Screenshot of the Nsight Systems timeline omitted; the NVTX range durations read off it are below.)

| NVTX range | HuggingFace (loop) | Transformer Engine (grouped) |
|------------|-------------------:|-----------------------------:|
| 1 layer    | +5.726 ms          | +6.171 ms                    |
| Attention  | +1.487 ms          | +2.351 ms                    |
| MLP        | +3.451 ms          | +3.486 ms                    |
| Expert GEMMs ("Loop" / "TE: grouped") | +952.678 μs | +966.614 μs |
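
These per-region numbers are read off NVTX ranges in the timeline. A hedged sketch of how such annotations are typically emitted from PyTorch (range names are illustrative, not necessarily the script's):

```python
import torch

def layer_forward_with_nvtx(layer, x):
    # Each torch.cuda.nvtx.range shows up as a labeled span in the
    # nsys timeline, which is where deltas like "+3.486 ms" come from.
    with torch.cuda.nvtx.range("1 layer"):
        with torch.cuda.nvtx.range("Attention"):
            x = layer.attention(x)
        with torch.cuda.nvtx.range("MLP"):
            # an inner "Loop" / "TE: grouped" range would wrap the
            # expert GEMM dispatch inside the MLP
            x = layer.mlp(x)
    return x
```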

Repro (in PR #2642)

```bash
cd docs/examples/te_mixtral

# Loop (Tier 2)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
    --improvement 2 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
    --warmup-steps 10 --train-steps 200

# GroupedLinear (Tier 3)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
    --improvement 3 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
    --warmup-steps 10 --train-steps 200
```
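
The step times in the table above are averages over the 200 timed steps taken after the 10 warmup steps. A sketch of the usual warmup-then-measure pattern (not the script's actual code):

```python
import time
import torch

def time_steps(step_fn, warmup=10, steps=200):
    for _ in range(warmup):       # let allocator/autotuning settle
        step_fn()
    torch.cuda.synchronize()      # drain queued GPU work before timing
    start = time.perf_counter()
    for _ in range(steps):
        step_fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps * 1e3  # ms per step
```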
