44
55import os
66import sys
7+ import copy
8+ import json
9+ import traceback
710import logging
811from contextlib import nullcontext
912import torch
@@ -209,10 +212,10 @@ def run_dpa_with_cp(
209212 os .environ ["NVTE_FUSED_ATTN" ] = "0"
210213 if kernel_backend == "FlashAttention" :
211214 os .environ ["NVTE_FLASH_ATTN" ] = "1"
212- config = model_configs_flash_attn [model ]
215+ config = copy . deepcopy ( model_configs_flash_attn [model ])
213216 if kernel_backend == "FusedAttention" :
214217 os .environ ["NVTE_FUSED_ATTN" ] = "1"
215- config = model_configs_fused_attn [model ]
218+ config = copy . deepcopy ( model_configs_fused_attn [model ])
216219 assert config .attn_mask_type in [
217220 "causal" ,
218221 "no_mask" ,
@@ -223,18 +226,13 @@ def run_dpa_with_cp(
223226 else :
224227 config .attn_mask_type = "padding"
225228
226- # set up distributed group
227- rank = int (os .getenv ("RANK" , "0" ))
228- world_size = int (os .getenv ("WORLD_SIZE" , "1" ))
229- if dist .is_initialized ():
230- world_size = dist .get_world_size ()
231- rank = dist .get_rank ()
232- else :
233- device_count = torch .cuda .device_count ()
234- device = rank % device_count
235- torch .cuda .set_device (device )
229+ # Process group is managed by main(); one init/destroy per torchrun, not per config.
230+ assert dist .is_initialized (), (
231+ "dist.init_process_group must be called before run_dpa_with_cp"
232+ )
233+ world_size = dist .get_world_size ()
234+ rank = dist .get_rank ()
236235 logging .info (f"[Rank { rank } ] Setup: world_size { world_size } " )
237- dist .init_process_group (backend = "nccl" , world_size = world_size , rank = rank )
238236
239237 # set up communication group for CP
240238 cp_comm_ranks = range (world_size )
@@ -630,7 +628,6 @@ def run_dpa_with_cp(
630628 == 0
631629 )
632630 else :
633- # Forward-only: reshape only out/out_ for comparison
634631 out = out .index_select (0 , seq_idx_q ).contiguous ()
635632 out_ = out_
636633
@@ -762,14 +759,105 @@ def run_dpa_with_cp(
762759 )
763760 logging .info (f"[Rank { rank } ] CP vs no-CP: { names [i ]} matches" )
764761
765- # destroy distribution group
766- dist .destroy_process_group ()
762+ # Destroy per-config communication groups so they don't leak into the next
763+ # config in batch mode. The global process group is torn down by main().
764+ dist .destroy_process_group (cp_comm_group )
765+ if cp_comm_type == "a2a+p2p" :
766+ for sg in cp_comm_sub_groups :
767+ dist .destroy_process_group (sg )
768+
769+
770+ # Env vars set by run_dpa_with_cp; cleared between batch configs to prevent leakage.
771+ _TRANSIENT_ENV_KEYS = (
772+ "NVTE_FP8_DPA_BWD" ,
773+ "NVTE_DPA_FP8CS_O_in_F16" ,
774+ "NVTE_FLASH_ATTN" ,
775+ "NVTE_FUSED_ATTN" ,
776+ "NVTE_ALLOW_NONDETERMINISTIC_ALGO" ,
777+ )
778+
779+
def _init_distributed():
    """Set the CUDA device and create the NCCL process group.

    Runs once per torchrun invocation and returns ``(rank, world_size)``.
    The device index comes from ``LOCAL_RANK`` when the launcher provides it
    (torchrun / torch.distributed.launch); otherwise it falls back to
    ``RANK % torch.cuda.device_count()``, which is correct for single-node runs.
    """
    env = os.environ
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    fallback_device = rank % torch.cuda.device_count()
    torch.cuda.set_device(int(env.get("LOCAL_RANK", str(fallback_device))))
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
    return rank, world_size
791+
792+
793+ def _run_single_config (kwargs ):
794+ """Run one config, return ``(ok, error_message)``.
795+
796+ Re-seeds RNG before each config so results are deterministic and
797+ order-independent within a batch.
798+ """
799+ torch .manual_seed (1234 )
800+ torch .cuda .manual_seed (1234 )
801+ try :
802+ run_dpa_with_cp (** kwargs )
803+ return True , None
804+ except BaseException : # noqa: BLE001 - capture any failure for per-config reporting
805+ return False , traceback .format_exc ()
767806
768807
def main(**kwargs):
    """Entry point: single-config (``key=val ...``) or batch (``batch_config_json=<path>``).

    In batch mode the JSON file holds a list of kwargs dicts, one per config.
    Rank 0 maintains a ``<path>.results.json`` sibling file containing one
    ``{"ok": bool, "error": str | None}`` entry per completed config.
    """
    batch_path = kwargs.pop("batch_config_json", None)
    # One process-group init per torchrun invocation; the matching destroy is
    # in the finally block below.
    rank, _ = _init_distributed()
    try:
        if batch_path is None:
            # Single-config mode: kwargs come straight from the command line.
            run_dpa_with_cp(**kwargs)
        else:
            with open(batch_path, "r") as f:
                configs = json.load(f)
            assert isinstance(configs, list), (
                f"batch_config_json must be a JSON list, got {type(configs)}"
            )
            results_path = batch_path + ".results.json"
            results = []

            def _flush_results():
                # Only rank 0 writes; other ranks' pass/fail status is folded
                # into rank 0's view via the all_reduce below.
                if rank != 0:
                    return
                # Atomic write: tmp + rename so the reader never sees partial JSON.
                tmp_path = results_path + ".tmp"
                with open(tmp_path, "w") as f:
                    json.dump(results, f)
                os.replace(tmp_path, results_path)

            for cfg in configs:
                # Clear env vars a previous config may have set, so each config
                # runs independently of its position in the batch.
                for env_key in _TRANSIENT_ENV_KEYS:
                    os.environ.pop(env_key, None)
                ok, err = _run_single_config(cfg)
                # Aggregate ok across ranks so a non-rank-0 failure (e.g. a
                # per-partition compare assertion that fires only on rank > 0)
                # is not silently swallowed when only rank 0 writes the result.
                ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda")
                dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN)
                ok_aggregate = bool(ok_tensor.item())
                if not ok_aggregate and ok and err is None:
                    # This rank passed but another rank failed: record a
                    # generic error since the remote traceback isn't available.
                    err = "Failed on a non-zero rank (see subprocess stderr for traceback)"
                results.append({"ok": ok_aggregate, "error": err})
                # Flush after every config so partial results survive a later
                # hang or crash.
                _flush_results()
                try:
                    # Keep ranks in lockstep between configs; a rank that died
                    # mid-config surfaces here as a barrier failure.
                    dist.barrier()
                except BaseException:  # noqa: BLE001
                    results[-1]["ok"] = False
                    if results[-1]["error"] is None:
                        results[-1]["error"] = traceback.format_exc()
                    _flush_results()
                    # The group is broken; running further configs would hang.
                    break
                # NOTE(review): presumably frees cached allocator blocks so one
                # config's peak memory doesn't constrain the next — confirm it
                # is needed; it is harmless otherwise.
                torch.cuda.empty_cache()
    finally:
        # Tear down the global process group exactly once per invocation,
        # on both the success and the exception path.
        if dist.is_initialized():
            dist.destroy_process_group()
771859
772860
if __name__ == "__main__":
    # NOTE(review): parsing starts at argv[2] — argv[1] is presumably consumed
    # by the launching harness; confirm against the caller.
    # maxsplit=1 keeps '=' characters inside values (e.g. paths) intact.
    cli_kwargs = {name: value for name, value in (token.split("=", 1) for token in sys.argv[2:])}
    main(**cli_kwargs)
0 commit comments