1- # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33#
44# Licensed under the Apache License, Version 2.0 (the "License");
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import os
1617from argparse import ArgumentParser
1718from itertools import product
1819
2627 from cuda import cudart
2728
2829import tensorrt_llm as tllm
30+ import tensorrt_llm .bindings .internal .userbuffers as ub
2931from tensorrt_llm import Mapping
3032from tensorrt_llm ._torch .autotuner import AutoTuner , autotune
33+ from tensorrt_llm ._torch .custom_ops .userbuffers_custom_ops import \
34+ copy_to_userbuffers
3135from tensorrt_llm ._torch .distributed import (AllReduce , AllReduceFusionOp ,
32- Distributed )
36+ Distributed ,
37+ userbuffers_allreduce_finalize )
3338from tensorrt_llm ._torch .modules .rms_norm import RMSNorm
3439from tensorrt_llm ._utils import (get_sm_version , local_mpi_rank , local_mpi_size ,
3540 nvtx_range )
@@ -52,6 +57,8 @@ def profile_allreduce(
5257 norm = None ,
5358 scale = None ,
5459 bias = None ,
60+ allreduce_instance = None ,
61+ dtype = None ,
5562):
5663
5764 allreduce_params = AllReduceParams (
@@ -63,7 +70,8 @@ def profile_allreduce(
6370 bias = bias ,
6471 )
6572
66- allreduce = AllReduce (mapping = mapping , strategy = strategy )
73+ allreduce = allreduce_instance or AllReduce (
74+ mapping = mapping , strategy = strategy , dtype = dtype )
6775
6876 def func (x , loop_num = inner_loop ):
6977 for _ in range (loop_num ):
@@ -273,6 +281,313 @@ def allreduce_benchmark(
273281 return df
274282
275283
# ── nccl-tests style comprehensive benchmark (--benchmark mode) ──────────────

# Maps CLI strategy names to AllReduceStrategy enum members.
_STRATEGY_MAP = {
    "NCCL": AllReduceStrategy.NCCL,
    "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC,
    "UB": AllReduceStrategy.UB,
    "ONESHOT": AllReduceStrategy.ONESHOT,
    "TWOSHOT": AllReduceStrategy.TWOSHOT,
    "AUTO": AllReduceStrategy.AUTO,
    "MNNVL": AllReduceStrategy.MNNVL,
}
# Strategies that need the user-buffers (UB) manager initialized before use
# (see the ub.initialize_userbuffers_manager call in allreduce_benchmark_all).
_UB_STRATEGIES = {AllReduceStrategy.NCCL_SYMMETRIC, AllReduceStrategy.UB}
# Maps CLI fusion names to AllReduceFusionOp enum members.
_FUSION_MAP = {
    "NONE": AllReduceFusionOp.NONE,
    "RESIDUAL_RMS_NORM": AllReduceFusionOp.RESIDUAL_RMS_NORM,
    "RESIDUAL_RMS_NORM_QUANT_FP8": AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
    "RESIDUAL_RMS_NORM_QUANT_NVFP4": AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
}
302+
303+
304+ def _fmt_size (nbytes ):
305+ """Format byte count as human-readable string (e.g. 256B, 4K, 1M, 2G)."""
306+ if nbytes < 1024 :
307+ return f"{ nbytes } B"
308+ elif nbytes < 1024 ** 2 :
309+ v = nbytes / 1024
310+ return f"{ v :.0f} K" if nbytes % 1024 == 0 else f"{ v :.1f} K"
311+ elif nbytes < 1024 ** 3 :
312+ v = nbytes / 1024 ** 2
313+ return f"{ v :.0f} M" if nbytes % (1024 ** 2 ) == 0 else f"{ v :.2f} M"
314+ else :
315+ v = nbytes / 1024 ** 3
316+ return f"{ v :.0f} G" if nbytes % (1024 ** 3 ) == 0 else f"{ v :.2f} G"
317+
318+
def _profile_ub(mapping, dist, allreduce, fusion, input, residual, norm, scale,
                enable_cudagraph=False, inner_loop=200, outer_loop=10):
    """Profile the UB allreduce kernel only; copy-in and finalize run once.

    The input is copied into user-buffer memory a single time before timing
    (standing in for the matmul_to_ub step of the real flow), and
    userbuffers_allreduce_finalize is invoked once after timing — so the
    reported time covers only the fused allreduce kernel itself.

    Args:
        mapping: TP Mapping for this rank (unused here beyond call parity
            with profile_allreduce — kept for a uniform caller interface).
        dist: distributed helper providing barrier().
        allreduce: pre-constructed AllReduce module to benchmark.
        fusion: AllReduceFusionOp to run.
        input: input tensor; copied into UB memory before the timed loop.
        residual: residual tensor for the fused RMSNorm path.
        norm: RMSNorm module supplying weight and variance_epsilon.
        scale: quantization scale tensor (bias is always None here).
        enable_cudagraph: capture the inner loop into a CUDA graph and
            replay it instead of re-launching eagerly.
        inner_loop: kernel launches per timed repetition.
        outer_loop: number of timed repetitions (median is reported).

    Returns:
        float: median time per allreduce call, in microseconds.
    """
    allreduce_params = AllReduceParams(
        fusion_op=fusion, residual=residual, norm_weight=norm.weight,
        eps=norm.variance_epsilon, scale=scale, bias=None)

    # Copy input into user-buffer memory once (simulates matmul_to_ub in real flow)
    ub_input = copy_to_userbuffers(input)

    def func(loop_num=inner_loop):
        # Re-launch on the same UB input; only the last output is returned.
        for _ in range(loop_num):
            output = allreduce(ub_input, all_reduce_params=allreduce_params)
        return output

    starts = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)]
    stops = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)]
    # Graph object is created unconditionally; it is only captured/replayed
    # when enable_cudagraph is set.
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # warmup
        for _ in range(4):
            func(loop_num=1)
        if enable_cudagraph:
            with torch.cuda.graph(graph, stream=stream):
                func()
        # Align all ranks, then pad the stream with a delay kernel so the
        # timed section starts with every rank already enqueued.
        dist.barrier()
        delay_kernel(20000, stream)
        torch.cuda.synchronize()
        for i in range(outer_loop):
            starts[i].record(stream)
            if enable_cudagraph:
                graph.replay()
            else:
                func()
            stops[i].record(stream)
        torch.cuda.synchronize()
        # Finalize once to sync (simulates userbuffers_allreduce_finalize in real flow)
        # NOTE(review): output[-1] assumes the fused op returns a sequence
        # whose last element is the UB tensor to finalize — confirm against
        # the AllReduce fused-op return convention.
        output = func(loop_num=1)
        userbuffers_allreduce_finalize(output[-1])
    runtimes = [starts[i].elapsed_time(stops[i]) for i in range(outer_loop)]
    # Median of the outer repetitions, converted from ms-per-inner_loop
    # to microseconds per single call.
    return sorted(runtimes)[len(runtimes) // 2] / inner_loop * 1000.0
361+
362+
363+ def _print_table (fusion_name , strategy_names , rows , world_size ):
364+ W_S , W_T , W_H , W_V , W_B = 10 , 6 , 6 , 10 , 16
365+ n = len (strategy_names )
366+ print (flush = True )
367+ print (f"# Fusion: { fusion_name } world_size={ world_size } "
368+ f"algbw = size / time (GB/s)" , flush = True )
369+ print ("#" , flush = True )
370+ fixed = f"{ 'size' :>{W_S }} { 'ntok' :>{W_T }} { 'hdim' :>{W_H }} "
371+ sh = " " .join (f"{ s :^{W_V * 2 + 2 }} " for s in strategy_names )
372+ print (f"# { fixed } { sh } { 'BEST' :>{W_B }} " , flush = True )
373+ pad = " " * (W_S + 2 + W_T + 2 + W_H )
374+ mh = " " .join (f"{ 'time(us)' :>{W_V }} { 'algbw' :>{W_V }} " for _ in strategy_names )
375+ print (f"# { pad } { mh } { ' ' :>{W_B }} " , flush = True )
376+ tw = 2 + W_S + 2 + W_T + 2 + W_H + 2 + n * (W_V * 2 + 2 ) + (n - 1 ) * 2 + 2 + W_B
377+ print ("#" + "-" * (tw - 1 ), flush = True )
378+ for row in rows :
379+ prefix = (f" { row ['size_human' ]:>{W_S }} "
380+ f"{ row ['num_tokens' ]:>{W_T }} "
381+ f"{ row ['hidden_size' ]:>{W_H }} " )
382+ vals , best_name , best_time = [], "N/A" , float ("inf" )
383+ for s in strategy_names :
384+ t , bw = row .get (f"{ s } _time" ), row .get (f"{ s } _algbw" )
385+ if t is not None :
386+ vals .append (f"{ t :>{W_V }.2f} { bw :>{W_V }.2f} " )
387+ if t < best_time :
388+ best_time , best_name = t , s
389+ else :
390+ vals .append (f"{ 'N/A' :>{W_V }} { 'N/A' :>{W_V }} " )
391+ print (f"{ prefix } { ' ' .join (vals )} { best_name :>{W_B }} " , flush = True )
392+
393+
def allreduce_benchmark_all(
    dtype='bfloat16',
    test_range="256,268435456,2",
    explore_2d=False,
    enable_cudagraph=False,
    strategy_names=None,
    fusion_names=None,
    inner_loop=200,
    outer_loop=10,
    save_csv=None,
):
    """Comprehensive benchmark: one table per fusion, all strategies side by side.

    Runs every (fusion, shape, strategy) combination, prints an
    nccl-tests-style table per fusion plus a peak-bandwidth summary, and
    optionally writes a long-format CSV.

    Args:
        dtype: tensor dtype name (e.g. ``'bfloat16'``).
        test_range: ``"min_bytes,max_bytes,ratio"`` geometric size sweep
            (ignored when explore_2d is set).
        explore_2d: sweep a (num_tokens x hidden_size) grid instead of a
            flat byte-size range.
        enable_cudagraph: time via CUDA-graph replay.
        strategy_names: subset of _STRATEGY_MAP keys; None means all.
        fusion_names: subset of _FUSION_MAP keys; None means all.
        inner_loop: kernel launches per timed repetition.
        outer_loop: timed repetitions per measurement.
        save_csv: optional path; rank 0 writes one CSV row per
            (fusion, shape, strategy) measurement.

    Raises:
        RuntimeError: when launched with a single MPI rank.
    """
    import csv as csv_mod

    world_size = tllm.mpi_world_size()
    rank = tllm.mpi_rank()
    local_rank = local_mpi_rank()
    gpus_per_node = local_mpi_size()

    # Bind this rank to its local GPU before creating any CUDA state.
    torch.cuda.set_device(local_rank)
    cudart.cudaSetDevice(local_rank)

    mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size)
    logger.set_rank(mapping.rank)
    AutoTuner.get().setup_distributed_state(mapping)
    dist = Distributed.get(mapping)
    sm_version = get_sm_version()

    if world_size == 1:
        raise RuntimeError("Benchmark requires mpi_world_size > 1")

    torch_dtype = tllm._utils.str_dtype_to_torch(dtype)
    elem_size = torch.finfo(torch_dtype).bits // 8

    # Enable MNNVL testing on single-node (bypasses multi-node NVLink check)
    os.environ["TLLM_TEST_MNNVL"] = "1"

    # Strategies to benchmark (columns of each table).
    if strategy_names is None:
        strategy_names = ["NCCL", "NCCL_SYMMETRIC", "UB", "ONESHOT",
                          "TWOSHOT", "AUTO", "MNNVL"]
    strategies = [_STRATEGY_MAP[s] for s in strategy_names]

    # Fusion ops to benchmark (one table each); NVFP4 needs SM100+.
    if fusion_names is None:
        fusion_names = list(_FUSION_MAP.keys())
    fusions = []
    for f in fusion_names:
        fop = _FUSION_MAP[f]
        if (fop == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
                and sm_version < 100):
            if rank == 0:
                print(f"[WARN] {f} requires SM100+, skipping.", flush=True)
            continue
        fusions.append((f, fop))

    # Shapes: either a 2D (num_tokens, hidden_size) grid or a geometric
    # byte-size sweep folded into (tokens, 4096)-ish shapes.
    if explore_2d:
        num_tokens_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024,
                           2048, 4096, 8192, 16384]
        hidden_size_list = [128, 256, 512, 1024, 2048, 4096, 8192]
        shape_list = list(product(num_tokens_list, hidden_size_list))
    else:
        min_bytes, max_bytes, ratio = [int(i) for i in test_range.split(",")]
        shape_list = []
        nbytes = min_bytes
        while nbytes <= max_bytes:
            total_elems = nbytes // elem_size
            if total_elems <= 4096:
                shape_list.append((1, max(total_elems, 1)))
            else:
                shape_list.append((total_elems // 4096, 4096))
            nbytes *= ratio

    # Initialize the user-buffers manager if any selected strategy needs it;
    # drop UB-based strategies when the platform does not support UB.
    need_ub = bool(_UB_STRATEGIES & set(strategies))
    if need_ub:
        if ub.ub_supported():
            max_elems = max(s[0] * s[1] for s in shape_list)
            ub.initialize_userbuffers_manager(
                world_size, 1, 1, rank, torch.cuda.device_count(),
                max_elems * elem_size)
        else:
            if rank == 0:
                print("[WARN] ub not supported, skipping UB-based strategies.",
                      flush=True)
            strategies = [s for s in strategies if s not in _UB_STRATEGIES]
            strategy_names = [s.name for s in strategies]

    # Create one AllReduce instance per strategy; drop strategies whose
    # construction fails on this platform.
    ar_instances = {}
    for strat in strategies:
        try:
            ar_instances[strat] = AllReduce(mapping=mapping, strategy=strat,
                                            dtype=torch_dtype)
        except Exception as e:
            if rank == 0:
                print(f"[WARN] Cannot init {strat.name}: {e}", flush=True)
    strategies = [s for s in strategies if s in ar_instances]
    strategy_names = [s.name for s in strategies]

    # ONESHOT/TWOSHOT are limited by the custom-allreduce workspace size.
    max_workspace = CustomAllReduceHelper.max_workspace_size_auto(
        mapping.tp_size)

    if rank == 0:
        print(f"\n{'=' * 80}", flush=True)
        print(" TRT-LLM AllReduce Benchmark", flush=True)
        print(f" world_size={world_size} dtype={dtype} SM={sm_version}"
              f" cudagraph={enable_cudagraph}"
              f" inner={inner_loop} outer={outer_loop}", flush=True)
        print(f" Strategies : {', '.join(strategy_names)}", flush=True)
        print(f" Fusions : {', '.join(f for f, _ in fusions)}", flush=True)
        print(f"{'=' * 80}", flush=True)

    csv_rows = []

    for fusion_name, fusion_op in fusions:
        table_rows = []
        for num_tokens, hidden_size in shape_list:
            msg_bytes = num_tokens * hidden_size * elem_size
            inp = torch.ones((num_tokens, hidden_size), dtype=torch_dtype,
                             device="cuda")
            res = torch.randn_like(inp)
            norm = RMSNorm(hidden_size=hidden_size, dtype=torch_dtype,
                           eps=1e-5).cuda()
            norm.weight.data.copy_(
                torch.randn((hidden_size,), dtype=torch_dtype, device="cuda"))
            scale = torch.tensor(1.0, dtype=torch.float32).cuda()

            row = dict(size_human=_fmt_size(msg_bytes),
                       num_tokens=num_tokens,
                       hidden_size=hidden_size,
                       size_bytes=msg_bytes)

            for strat in strategies:
                sn = strat.name
                # Skip combos the strategy cannot run:
                #  - TWOSHOT needs at least one token per rank,
                #  - ONESHOT/TWOSHOT are capped by the workspace size,
                #  - UB requires a fused op (plain allreduce unsupported).
                skip = False
                if (strat == AllReduceStrategy.TWOSHOT
                        and num_tokens < world_size):
                    skip = True
                elif (strat in (AllReduceStrategy.ONESHOT,
                                AllReduceStrategy.TWOSHOT)
                      and msg_bytes > max_workspace):
                    skip = True
                elif (strat == AllReduceStrategy.UB
                      and fusion_op == AllReduceFusionOp.NONE):
                    skip = True

                if skip:
                    row[f"{sn}_time"] = row[f"{sn}_algbw"] = None
                else:
                    try:
                        if strat == AllReduceStrategy.UB:
                            # UB path has its own profiler (copy-in and
                            # finalize happen outside the timed loop).
                            t_us = _profile_ub(
                                mapping, dist, ar_instances[strat], fusion_op,
                                inp, res, norm, scale, enable_cudagraph,
                                inner_loop, outer_loop)
                        else:
                            # profile_allreduce returns ms; convert to us.
                            t_us = profile_allreduce(
                                mapping=mapping, dist=dist,
                                enable_cudagraph=enable_cudagraph,
                                inner_loop=inner_loop, outer_loop=outer_loop,
                                fusion=fusion_op, input=inp, residual=res,
                                norm=norm, scale=scale,
                                allreduce_instance=ar_instances[strat]) * 1000.0
                        row[f"{sn}_time"] = t_us
                        # algbw in GB/s: bytes / seconds / 1e9.
                        row[f"{sn}_algbw"] = msg_bytes / (t_us / 1e6) / 1e9
                    except Exception as e:
                        if rank == 0:
                            print(f" [SKIP] {sn} @ {_fmt_size(msg_bytes)}: {e}",
                                  flush=True)
                        row[f"{sn}_time"] = row[f"{sn}_algbw"] = None

                # One long-format CSV row per (fusion, shape, strategy).
                csv_rows.append({
                    "world_size": world_size, "dtype": dtype,
                    "fusion": fusion_name,
                    "num_tokens": num_tokens, "hidden_size": hidden_size,
                    "size_bytes": msg_bytes, "strategy": sn,
                    "time_us": row[f"{sn}_time"] or 0.0,
                    "algbw_GBps": row[f"{sn}_algbw"] or 0.0,
                })
            table_rows.append(row)

        if rank == 0:
            _print_table(fusion_name, strategy_names, table_rows, world_size)

    # Summary: peak achieved bandwidth per strategy per fusion.
    if rank == 0:
        print(f"\n{'=' * 80}", flush=True)
        print(" Summary: peak algbw (GB/s) per strategy per fusion",
              flush=True)
        print(f"{'=' * 80}", flush=True)
        hdr = f" {'fusion':<35s}" + "".join(f" {s:>14s}"
                                            for s in strategy_names)
        print(hdr, flush=True)
        print(" " + "-" * (len(hdr) - 2), flush=True)
        for fn, _ in fusions:
            line = f" {fn:<35s}"
            for sn in strategy_names:
                bws = [r["algbw_GBps"] for r in csv_rows
                       if r["fusion"] == fn and r["strategy"] == sn
                       and r["algbw_GBps"] > 0]
                line += f" {max(bws) if bws else 0.0:>14.2f}"
            print(line, flush=True)
        print(flush=True)

    if rank == 0 and save_csv and csv_rows:
        with open(save_csv, "w", newline="") as f:
            writer = csv_mod.DictWriter(f, fieldnames=csv_rows[0].keys())
            writer.writeheader()
            writer.writerows(csv_rows)
        print(f"Results saved to {save_csv}", flush=True)
589+
590+
276591if __name__ == "__main__" :
277592 parser = ArgumentParser ()
278593 parser .add_argument ("--dtype" , "-t" , default = "bfloat16" )
@@ -285,14 +600,26 @@ def allreduce_benchmark(
285600 parser .add_argument ("--enable_cudagraph" , action = "store_true" )
286601 parser .add_argument ("--save_csv" , type = str , default = None )
287602 parser .add_argument ("--enable_auto" , action = "store_true" , default = False )
603+ parser .add_argument ("--benchmark" , action = "store_true" , default = False ,
604+ help = "Run comprehensive benchmark across all backends "
605+ "with nccl-tests style output" )
288606
289607 args = parser .parse_args ()
290608
291- allreduce_benchmark (
292- args .dtype ,
293- args .range ,
294- args .enable_cudagraph ,
295- args .explore_2d ,
296- args .save_csv ,
297- args .enable_auto ,
298- )
609+ if args .benchmark :
610+ allreduce_benchmark_all (
611+ dtype = args .dtype ,
612+ test_range = args .range ,
613+ explore_2d = args .explore_2d ,
614+ enable_cudagraph = args .enable_cudagraph ,
615+ save_csv = args .save_csv ,
616+ )
617+ else :
618+ allreduce_benchmark (
619+ args .dtype ,
620+ args .range ,
621+ args .enable_cudagraph ,
622+ args .explore_2d ,
623+ args .save_csv ,
624+ args .enable_auto ,
625+ )
0 commit comments