Skip to content

Commit 394a56f

Browse files
committed
Add benchmark for all allreduce backends
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
1 parent ce71620 commit 394a56f

1 file changed

Lines changed: 338 additions & 11 deletions

File tree

tests/microbenchmarks/all_reduce.py

Lines changed: 338 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import os
1617
from argparse import ArgumentParser
1718
from itertools import product
1819

@@ -26,10 +27,14 @@
2627
from cuda import cudart
2728

2829
import tensorrt_llm as tllm
30+
import tensorrt_llm.bindings.internal.userbuffers as ub
2931
from tensorrt_llm import Mapping
3032
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
33+
from tensorrt_llm._torch.custom_ops.userbuffers_custom_ops import \
34+
copy_to_userbuffers
3135
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
32-
Distributed)
36+
Distributed,
37+
userbuffers_allreduce_finalize)
3338
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
3439
from tensorrt_llm._utils import (get_sm_version, local_mpi_rank, local_mpi_size,
3540
nvtx_range)
@@ -52,6 +57,8 @@ def profile_allreduce(
5257
norm=None,
5358
scale=None,
5459
bias=None,
60+
allreduce_instance=None,
61+
dtype=None,
5562
):
5663

5764
allreduce_params = AllReduceParams(
@@ -63,7 +70,8 @@ def profile_allreduce(
6370
bias=bias,
6471
)
6572

66-
allreduce = AllReduce(mapping=mapping, strategy=strategy)
73+
allreduce = allreduce_instance or AllReduce(
74+
mapping=mapping, strategy=strategy, dtype=dtype)
6775

6876
def func(x, loop_num=inner_loop):
6977
for _ in range(loop_num):
@@ -273,6 +281,313 @@ def allreduce_benchmark(
273281
return df
274282

275283

284+
# ── nccl-tests style comprehensive benchmark (--benchmark mode) ──────────────

# Maps CLI strategy names (as passed on the command line) to the
# AllReduceStrategy enum members used when constructing AllReduce instances.
_STRATEGY_MAP = {
    "NCCL": AllReduceStrategy.NCCL,
    "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC,
    "UB": AllReduceStrategy.UB,
    "ONESHOT": AllReduceStrategy.ONESHOT,
    "TWOSHOT": AllReduceStrategy.TWOSHOT,
    "AUTO": AllReduceStrategy.AUTO,
    "MNNVL": AllReduceStrategy.MNNVL,
}
# Strategies that require the userbuffers (UB) manager to be initialized
# before use (see the need_ub check in allreduce_benchmark_all).
_UB_STRATEGIES = {AllReduceStrategy.NCCL_SYMMETRIC, AllReduceStrategy.UB}
# Maps CLI fusion names to AllReduceFusionOp enum members; NVFP4 quant is
# gated on SM100+ at benchmark time.
_FUSION_MAP = {
    "NONE": AllReduceFusionOp.NONE,
    "RESIDUAL_RMS_NORM": AllReduceFusionOp.RESIDUAL_RMS_NORM,
    "RESIDUAL_RMS_NORM_QUANT_FP8": AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
    "RESIDUAL_RMS_NORM_QUANT_NVFP4": AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
}
302+
303+
304+
def _fmt_size(nbytes):
305+
"""Format byte count as human-readable string (e.g. 256B, 4K, 1M, 2G)."""
306+
if nbytes < 1024:
307+
return f"{nbytes}B"
308+
elif nbytes < 1024**2:
309+
v = nbytes / 1024
310+
return f"{v:.0f}K" if nbytes % 1024 == 0 else f"{v:.1f}K"
311+
elif nbytes < 1024**3:
312+
v = nbytes / 1024**2
313+
return f"{v:.0f}M" if nbytes % (1024**2) == 0 else f"{v:.2f}M"
314+
else:
315+
v = nbytes / 1024**3
316+
return f"{v:.0f}G" if nbytes % (1024**3) == 0 else f"{v:.2f}G"
317+
318+
319+
def _profile_ub(mapping, dist, allreduce, fusion, input, residual, norm, scale,
                enable_cudagraph=False, inner_loop=200, outer_loop=10):
    """Profile UB allreduce kernel only (copy_to_ub and finalize are one-shot).

    Times `inner_loop` back-to-back allreduce calls on a UB-resident input,
    repeated `outer_loop` times, and returns the median per-call latency in
    microseconds.

    Args:
        mapping: tensor-parallel Mapping (unused here beyond documentation;
            the AllReduce instance was already built with it).
        dist: Distributed handle used to barrier all ranks before timing.
        allreduce: pre-constructed AllReduce instance with a UB strategy.
        fusion: AllReduceFusionOp applied on every call.
        input: CUDA tensor to reduce; copied into user-buffer memory once.
        residual, norm, scale: fusion operands (residual add + RMSNorm + scale).
        enable_cudagraph: capture the inner loop into a CUDA graph and replay it.
        inner_loop: allreduce calls per timed region.
        outer_loop: number of timed regions; median is reported.
    """
    allreduce_params = AllReduceParams(
        fusion_op=fusion, residual=residual, norm_weight=norm.weight,
        eps=norm.variance_epsilon, scale=scale, bias=None)

    # Copy input into user-buffer memory once (simulates matmul_to_ub in real flow)
    ub_input = copy_to_userbuffers(input)

    def func(loop_num=inner_loop):
        # Returns the result of the last allreduce call only; intermediate
        # results are intentionally discarded (steady-state kernel timing).
        for _ in range(loop_num):
            output = allreduce(ub_input, all_reduce_params=allreduce_params)
        return output

    starts = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)]
    stops = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)]
    # NOTE(review): the graph object is created unconditionally but only
    # captured/replayed when enable_cudagraph is set.
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # warmup
        for _ in range(4):
            func(loop_num=1)
        if enable_cudagraph:
            with torch.cuda.graph(graph, stream=stream):
                func()
        # Align all ranks, then keep the GPU busy so the timed kernels launch
        # into a non-idle stream (avoids measuring launch latency).
        dist.barrier()
        delay_kernel(20000, stream)
        torch.cuda.synchronize()
        for i in range(outer_loop):
            starts[i].record(stream)
            if enable_cudagraph:
                graph.replay()
            else:
                func()
            stops[i].record(stream)
        torch.cuda.synchronize()
        # Finalize once to sync (simulates userbuffers_allreduce_finalize in real flow)
        # NOTE(review): assumes the fused allreduce returns a sequence whose
        # last element is the UB tensor to finalize — TODO confirm against
        # AllReduce's return contract for fused ops.
        output = func(loop_num=1)
        userbuffers_allreduce_finalize(output[-1])
    # elapsed_time is in milliseconds; convert median per-call time to µs.
    runtimes = [starts[i].elapsed_time(stops[i]) for i in range(outer_loop)]
    return sorted(runtimes)[len(runtimes) // 2] / inner_loop * 1000.0
361+
362+
363+
def _print_table(fusion_name, strategy_names, rows, world_size):
    """Print one nccl-tests style results table for a single fusion op.

    Each row shows message size, token/hidden dims, then a time(us)/algbw
    pair per strategy, and finally the name of the fastest strategy.

    Args:
        fusion_name: display name of the fusion op for the table header.
        strategy_names: ordered strategy column names; rows are keyed by
            f"{name}_time" / f"{name}_algbw" (None means skipped).
        rows: list of dicts produced by allreduce_benchmark_all.
        world_size: shown in the header only.
    """
    # Column widths: size, ntok, hdim, per-metric value, BEST column.
    W_S, W_T, W_H, W_V, W_B = 10, 6, 6, 10, 16
    n = len(strategy_names)
    print(flush=True)
    print(f"# Fusion: {fusion_name} world_size={world_size} "
          f"algbw = size / time (GB/s)", flush=True)
    print("#", flush=True)
    fixed = f"{'size':>{W_S}} {'ntok':>{W_T}} {'hdim':>{W_H}}"
    sh = " ".join(f"{s:^{W_V * 2 + 2}}" for s in strategy_names)
    print(f"# {fixed} {sh} {'BEST':>{W_B}}", flush=True)
    pad = " " * (W_S + 2 + W_T + 2 + W_H)
    mh = " ".join(f"{'time(us)':>{W_V}} {'algbw':>{W_V}}" for _ in strategy_names)
    print(f"# {pad} {mh} {' ':>{W_B}}", flush=True)
    # NOTE(review): the rule-width arithmetic below assumes 2-character
    # separators between columns ((n - 1) * 2, the "+ 2" terms, and pad's
    # "+ 2"s), while the f-string separators above appear as single spaces —
    # likely the extraction collapsed double spaces; confirm against the
    # original file before relying on exact column alignment.
    tw = 2 + W_S + 2 + W_T + 2 + W_H + 2 + n * (W_V * 2 + 2) + (n - 1) * 2 + 2 + W_B
    print("#" + "-" * (tw - 1), flush=True)
    for row in rows:
        prefix = (f" {row['size_human']:>{W_S}} "
                  f"{row['num_tokens']:>{W_T}} "
                  f"{row['hidden_size']:>{W_H}}")
        vals, best_name, best_time = [], "N/A", float("inf")
        for s in strategy_names:
            t, bw = row.get(f"{s}_time"), row.get(f"{s}_algbw")
            if t is not None:
                vals.append(f"{t:>{W_V}.2f} {bw:>{W_V}.2f}")
                # Track the fastest (lowest-latency) strategy for this row.
                if t < best_time:
                    best_time, best_name = t, s
            else:
                # Strategy was skipped or failed for this shape.
                vals.append(f"{'N/A':>{W_V}} {'N/A':>{W_V}}")
        print(f"{prefix} {' '.join(vals)} {best_name:>{W_B}}", flush=True)
392+
393+
394+
def allreduce_benchmark_all(
    dtype='bfloat16',
    test_range="256,268435456,2",
    explore_2d=False,
    enable_cudagraph=False,
    strategy_names=None,
    fusion_names=None,
    inner_loop=200,
    outer_loop=10,
    save_csv=None,
):
    """Comprehensive benchmark: one table per fusion, all strategies side by side.

    Must be launched under MPI with world_size > 1; each rank binds to its
    local GPU. Rank 0 prints one nccl-tests style table per fusion op, a
    peak-bandwidth summary, and optionally writes all samples to CSV.

    Args:
        dtype: tensor dtype name (must be a float type; element size is
            derived via torch.finfo).
        test_range: "min_bytes,max_bytes,ratio" geometric sweep of message
            sizes (used when explore_2d is False).
        explore_2d: sweep a (num_tokens x hidden_size) grid instead of a
            1-D byte-size range.
        enable_cudagraph: capture timed loops into CUDA graphs.
        strategy_names: subset of _STRATEGY_MAP keys; defaults to all.
        fusion_names: subset of _FUSION_MAP keys; defaults to all.
        inner_loop/outer_loop: timing loop parameters (see _profile_ub).
        save_csv: optional path; rank 0 writes one row per
            (fusion, shape, strategy) sample.

    Raises:
        RuntimeError: if run with a single MPI rank.
    """
    import csv as csv_mod

    world_size = tllm.mpi_world_size()
    rank = tllm.mpi_rank()
    local_rank = local_mpi_rank()
    gpus_per_node = local_mpi_size()

    # Bind this rank to its local GPU for both torch and the CUDA runtime.
    torch.cuda.set_device(local_rank)
    cudart.cudaSetDevice(local_rank)

    mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size)
    logger.set_rank(mapping.rank)
    AutoTuner.get().setup_distributed_state(mapping)
    dist = Distributed.get(mapping)
    sm_version = get_sm_version()

    if world_size == 1:
        raise RuntimeError("Benchmark requires mpi_world_size > 1")

    torch_dtype = tllm._utils.str_dtype_to_torch(dtype)
    # NOTE(review): torch.finfo only works for floating dtypes; integer
    # dtypes would need torch.iinfo.
    elem_size = torch.finfo(torch_dtype).bits // 8

    # Enable MNNVL testing on single-node (bypasses multi-node NVLink check)
    os.environ["TLLM_TEST_MNNVL"] = "1"

    # strategies
    if strategy_names is None:
        strategy_names = ["NCCL", "NCCL_SYMMETRIC", "UB", "ONESHOT",
                          "TWOSHOT", "AUTO", "MNNVL"]
    strategies = [_STRATEGY_MAP[s] for s in strategy_names]

    # fusions
    if fusion_names is None:
        fusion_names = list(_FUSION_MAP.keys())
    fusions = []
    for f in fusion_names:
        fop = _FUSION_MAP[f]
        # NVFP4 quantization requires Blackwell (SM100) or newer.
        if fop == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 and sm_version < 100:
            if rank == 0:
                print(f"[WARN] {f} requires SM100+, skipping.", flush=True)
            continue
        fusions.append((f, fop))

    # shapes
    if explore_2d:
        num_tokens_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
        hidden_size_list = [128, 256, 512, 1024, 2048, 4096, 8192]
        shape_list = list(product(num_tokens_list, hidden_size_list))
    else:
        # Geometric sweep from min_bytes to max_bytes; small messages become
        # a single row, larger ones are reshaped to hidden_size 4096.
        min_bytes, max_bytes, ratio = [int(i) for i in test_range.split(",")]
        shape_list = []
        nbytes = min_bytes
        while nbytes <= max_bytes:
            total_elems = nbytes // elem_size
            if total_elems <= 4096:
                shape_list.append((1, max(total_elems, 1)))
            else:
                shape_list.append((total_elems // 4096, 4096))
            nbytes *= ratio

    # init user-buffers
    need_ub = bool(_UB_STRATEGIES & set(strategies))
    if need_ub:
        if ub.ub_supported():
            # Size the UB pool for the largest message in the sweep.
            max_elems = max(s[0] * s[1] for s in shape_list)
            ub.initialize_userbuffers_manager(
                world_size, 1, 1, rank, torch.cuda.device_count(),
                max_elems * elem_size)
        else:
            if rank == 0:
                print("[WARN] ub not supported, skipping UB-based strategies.", flush=True)
            strategies = [s for s in strategies if s not in _UB_STRATEGIES]
            strategy_names = [s.name for s in strategies]

    # create AllReduce instances
    # Strategies whose constructor fails (unsupported topology etc.) are
    # dropped from the benchmark rather than aborting the whole run.
    ar_instances = {}
    for strat in strategies:
        try:
            ar_instances[strat] = AllReduce(mapping=mapping, strategy=strat, dtype=torch_dtype)
        except Exception as e:
            if rank == 0:
                print(f"[WARN] Cannot init {strat.name}: {e}", flush=True)
    strategies = [s for s in strategies if s in ar_instances]
    strategy_names = [s.name for s in strategies]

    # ONESHOT/TWOSHOT are limited by the custom-allreduce workspace size.
    max_workspace = CustomAllReduceHelper.max_workspace_size_auto(mapping.tp_size)

    if rank == 0:
        print(f"\n{'=' * 80}", flush=True)
        print(" TRT-LLM AllReduce Benchmark", flush=True)
        print(f" world_size={world_size} dtype={dtype} SM={sm_version}"
              f" cudagraph={enable_cudagraph}"
              f" inner={inner_loop} outer={outer_loop}", flush=True)
        print(f" Strategies : {', '.join(strategy_names)}", flush=True)
        print(f" Fusions : {', '.join(f for f, _ in fusions)}", flush=True)
        print(f"{'=' * 80}", flush=True)

    csv_rows = []

    for fusion_name, fusion_op in fusions:
        table_rows = []
        for num_tokens, hidden_size in shape_list:
            msg_bytes = num_tokens * hidden_size * elem_size
            inp = torch.ones((num_tokens, hidden_size), dtype=torch_dtype, device="cuda")
            res = torch.randn_like(inp)
            norm = RMSNorm(hidden_size=hidden_size, dtype=torch_dtype, eps=1e-5).cuda()
            norm.weight.data.copy_(torch.randn((hidden_size,), dtype=torch_dtype, device="cuda"))
            scale = torch.tensor(1.0, dtype=torch.float32).cuda()

            row = dict(size_human=_fmt_size(msg_bytes), num_tokens=num_tokens,
                       hidden_size=hidden_size, size_bytes=msg_bytes)

            for strat in strategies:
                sn = strat.name
                # skip invalid combos
                skip = False
                if strat == AllReduceStrategy.TWOSHOT and num_tokens < world_size:
                    # TWOSHOT shards across tokens; needs >= world_size tokens.
                    skip = True
                elif strat in (AllReduceStrategy.ONESHOT, AllReduceStrategy.TWOSHOT) \
                        and msg_bytes > max_workspace:
                    skip = True
                elif strat == AllReduceStrategy.UB and fusion_op == AllReduceFusionOp.NONE:
                    # UB path is only exercised with a fused op.
                    skip = True

                if skip:
                    row[f"{sn}_time"] = row[f"{sn}_algbw"] = None
                else:
                    try:
                        if strat == AllReduceStrategy.UB:
                            # _profile_ub already returns microseconds.
                            t_us = _profile_ub(
                                mapping, dist, ar_instances[strat], fusion_op,
                                inp, res, norm, scale, enable_cudagraph,
                                inner_loop, outer_loop)
                        else:
                            # profile_allreduce returns ms; convert to µs.
                            t_us = profile_allreduce(
                                mapping=mapping, dist=dist,
                                enable_cudagraph=enable_cudagraph,
                                inner_loop=inner_loop, outer_loop=outer_loop,
                                fusion=fusion_op, input=inp, residual=res,
                                norm=norm, scale=scale,
                                allreduce_instance=ar_instances[strat]) * 1000.0
                        row[f"{sn}_time"] = t_us
                        # algbw in GB/s = bytes / seconds / 1e9.
                        row[f"{sn}_algbw"] = msg_bytes / (t_us / 1e6) / 1e9
                    except Exception as e:
                        if rank == 0:
                            print(f" [SKIP] {sn} @ {_fmt_size(msg_bytes)}: {e}", flush=True)
                        row[f"{sn}_time"] = row[f"{sn}_algbw"] = None

                # One CSV sample per (fusion, shape, strategy); skipped/failed
                # samples are recorded as 0.0.
                csv_rows.append({
                    "world_size": world_size, "dtype": dtype, "fusion": fusion_name,
                    "num_tokens": num_tokens, "hidden_size": hidden_size,
                    "size_bytes": msg_bytes, "strategy": sn,
                    "time_us": row[f"{sn}_time"] or 0.0,
                    "algbw_GBps": row[f"{sn}_algbw"] or 0.0,
                })
            table_rows.append(row)

        if rank == 0:
            _print_table(fusion_name, strategy_names, table_rows, world_size)

    # summary
    if rank == 0:
        print(f"\n{'=' * 80}", flush=True)
        print(" Summary: peak algbw (GB/s) per strategy per fusion", flush=True)
        print(f"{'=' * 80}", flush=True)
        hdr = f" {'fusion':<35s}" + "".join(f" {s:>14s}" for s in strategy_names)
        print(hdr, flush=True)
        print(" " + "-" * (len(hdr) - 2), flush=True)
        for fn, _ in fusions:
            line = f" {fn:<35s}"
            for sn in strategy_names:
                bws = [r["algbw_GBps"] for r in csv_rows
                       if r["fusion"] == fn and r["strategy"] == sn and r["algbw_GBps"] > 0]
                line += f" {max(bws) if bws else 0.0:>14.2f}"
            print(line, flush=True)
        print(flush=True)

    if rank == 0 and save_csv and csv_rows:
        with open(save_csv, "w", newline="") as f:
            writer = csv_mod.DictWriter(f, fieldnames=csv_rows[0].keys())
            writer.writeheader()
            writer.writerows(csv_rows)
        print(f"Results saved to {save_csv}", flush=True)
589+
590+
276591
if __name__ == "__main__":
277592
parser = ArgumentParser()
278593
parser.add_argument("--dtype", "-t", default="bfloat16")
@@ -285,14 +600,26 @@ def allreduce_benchmark(
285600
parser.add_argument("--enable_cudagraph", action="store_true")
286601
parser.add_argument("--save_csv", type=str, default=None)
287602
parser.add_argument("--enable_auto", action="store_true", default=False)
603+
parser.add_argument("--benchmark", action="store_true", default=False,
604+
help="Run comprehensive benchmark across all backends "
605+
"with nccl-tests style output")
288606

289607
args = parser.parse_args()
290608

291-
allreduce_benchmark(
292-
args.dtype,
293-
args.range,
294-
args.enable_cudagraph,
295-
args.explore_2d,
296-
args.save_csv,
297-
args.enable_auto,
298-
)
609+
if args.benchmark:
610+
allreduce_benchmark_all(
611+
dtype=args.dtype,
612+
test_range=args.range,
613+
explore_2d=args.explore_2d,
614+
enable_cudagraph=args.enable_cudagraph,
615+
save_csv=args.save_csv,
616+
)
617+
else:
618+
allreduce_benchmark(
619+
args.dtype,
620+
args.range,
621+
args.enable_cudagraph,
622+
args.explore_2d,
623+
args.save_csv,
624+
args.enable_auto,
625+
)

0 commit comments

Comments
 (0)