Commit c80df5d
[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init
Each parametrized CP test currently spawns its own torchrun process and pays 5-15s of NCCL init/destroy. With ~650-800 collected tests, this adds up to 1.5-3 hours of pure setup overhead. This change introduces a session-scoped fixture that:

1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or a kwargs dict for the worker.
2. Groups runnable configs by ``num_gpus`` and chunks them into batches of CP_TEST_BATCH_SIZE (default 16).
3. Launches one torchrun per chunk; the worker initialises NCCL once and runs all configs in the chunk inside the same world.

Per-config results are flushed to JSON after every config, so a crash mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1 to bisect a failing batch.

Also includes a small bugfix in dot_product_attention/utils.py: the deterministic-FA3 disable condition was firing for any head_dim_qk > 128 (including inference); it is now restricted to is_training with large head dims.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
1 parent 4b6923d commit c80df5d
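The session-scoped fixture described above lives in the test file rather than in the worker shown below. A minimal sketch of the grouping and chunking in steps 2-3, assuming a hypothetical launch helper and a hypothetical ``num_gpus`` key on each prepared config (only the CP_TEST_BATCH_SIZE knob and its default of 16 come from the commit):

    # Sketch only: launch_torchrun is a stand-in for the real helper in the
    # test file, which is not shown in this view.
    import os
    from collections import defaultdict

    BATCH_SIZE = int(os.getenv("CP_TEST_BATCH_SIZE", "16"))

    def run_batched(runnable_configs, launch_torchrun):
        """runnable_configs: kwargs dicts produced by the _prepare_* helpers."""
        by_num_gpus = defaultdict(list)
        for cfg in runnable_configs:
            # Assumes each prepared config records how many GPUs it needs.
            by_num_gpus[cfg["num_gpus"]].append(cfg)
        results = []
        for num_gpus, cfgs in by_num_gpus.items():
            for start in range(0, len(cfgs), BATCH_SIZE):
                chunk = cfgs[start : start + BATCH_SIZE]
                # One torchrun per chunk: NCCL init/destroy is paid once per
                # chunk instead of once per config.
                results.extend(launch_torchrun(num_gpus, chunk))
        return results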

2 files changed

Lines changed: 377 additions & 48 deletions

File tree

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 106 additions & 18 deletions
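Only the run_attention_with_cp.py diff is reproduced below; the dot_product_attention/utils.py fix described in the commit message is not part of this view. A hedged sketch of that condition change, with illustrative names rather than the exact TransformerEngine source:

    # Illustrative sketch only; the real condition lives in
    # dot_product_attention/utils.py.
    def fa3_allowed(deterministic: bool, is_training: bool, head_dim_qk: int) -> bool:
        # Before the fix, the is_training term was missing, so deterministic
        # inference with head_dim_qk > 128 also lost FA3, even though the
        # backward pass (the nondeterministic part) never runs at inference.
        return not (deterministic and is_training and head_dim_qk > 128)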
@@ -4,6 +4,9 @@

 import os
 import sys
+import copy
+import json
+import traceback
 import logging
 from contextlib import nullcontext
 import torch
@@ -209,10 +212,10 @@ def run_dpa_with_cp(
         os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
-        config = model_configs_flash_attn[model]
+        config = copy.deepcopy(model_configs_flash_attn[model])
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
-        config = model_configs_fused_attn[model]
+        config = copy.deepcopy(model_configs_fused_attn[model])
     assert config.attn_mask_type in [
         "causal",
         "no_mask",
@@ -223,18 +226,13 @@ def run_dpa_with_cp(
     else:
         config.attn_mask_type = "padding"

-    # set up distributed group
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-    else:
-        device_count = torch.cuda.device_count()
-        device = rank % device_count
-        torch.cuda.set_device(device)
+    # Process group is managed by main(); one init/destroy per torchrun, not per config.
+    assert dist.is_initialized(), (
+        "dist.init_process_group must be called before run_dpa_with_cp"
+    )
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
     logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
-    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

     # set up communication group for CP
     cp_comm_ranks = range(world_size)
@@ -630,7 +628,6 @@ def run_dpa_with_cp(
             == 0
         )
     else:
-        # Forward-only: reshape only out/out_ for comparison
        out = out.index_select(0, seq_idx_q).contiguous()
        out_ = out_

@@ -762,14 +759,105 @@ def run_dpa_with_cp(
         )
         logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")

-    # destroy distribution group
-    dist.destroy_process_group()
+    # Destroy per-config communication groups so they don't leak into the next
+    # config in batch mode. The global process group is torn down by main().
+    dist.destroy_process_group(cp_comm_group)
+    if cp_comm_type == "a2a+p2p":
+        for sg in cp_comm_sub_groups:
+            dist.destroy_process_group(sg)
+
+
+# Env vars set by run_dpa_with_cp; cleared between batch configs to prevent leakage.
+_TRANSIENT_ENV_KEYS = (
+    "NVTE_FP8_DPA_BWD",
+    "NVTE_DPA_FP8CS_O_in_F16",
+    "NVTE_FLASH_ATTN",
+    "NVTE_FUSED_ATTN",
+    "NVTE_ALLOW_NONDETERMINISTIC_ALGO",
+)
+
+
+def _init_distributed():
+    """Init NCCL process group + CUDA device once per torchrun invocation."""
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    device_count = torch.cuda.device_count()
+    # Prefer LOCAL_RANK when available (set by torchrun / torch.distributed.launch);
+    # fall back to RANK % device_count for single-node runs.
+    local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count)))
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+    return rank, world_size
+
+
+def _run_single_config(kwargs):
+    """Run one config, return ``(ok, error_message)``.
+
+    Re-seeds RNG before each config so results are deterministic and
+    order-independent within a batch.
+    """
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed(1234)
+    try:
+        run_dpa_with_cp(**kwargs)
+        return True, None
+    except BaseException:  # noqa: BLE001 - capture any failure for per-config reporting
+        return False, traceback.format_exc()


 def main(**kwargs):
-    run_dpa_with_cp(**kwargs)
+    """Entry point: single-config (``key=val ...``) or batch (``batch_config_json=<path>``)."""
+    batch_path = kwargs.pop("batch_config_json", None)
+    rank, _ = _init_distributed()
+    try:
+        if batch_path is None:
+            run_dpa_with_cp(**kwargs)
+        else:
+            with open(batch_path, "r") as f:
+                configs = json.load(f)
+            assert isinstance(configs, list), (
+                f"batch_config_json must be a JSON list, got {type(configs)}"
+            )
+            results_path = batch_path + ".results.json"
+            results = []
+
+            def _flush_results():
+                if rank != 0:
+                    return
+                # Atomic write: tmp + rename so the reader never sees partial JSON.
+                tmp_path = results_path + ".tmp"
+                with open(tmp_path, "w") as f:
+                    json.dump(results, f)
+                os.replace(tmp_path, results_path)
+
+            for cfg in configs:
+                for env_key in _TRANSIENT_ENV_KEYS:
+                    os.environ.pop(env_key, None)
+                ok, err = _run_single_config(cfg)
+                # Aggregate ok across ranks so a non-rank-0 failure (e.g. a
+                # per-partition compare assertion that fires only on rank > 0)
+                # is not silently swallowed when only rank 0 writes the result.
+                ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda")
+                dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN)
+                ok_aggregate = bool(ok_tensor.item())
+                if not ok_aggregate and ok and err is None:
+                    err = "Failed on a non-zero rank (see subprocess stderr for traceback)"
+                results.append({"ok": ok_aggregate, "error": err})
+                _flush_results()
+                try:
+                    dist.barrier()
+                except BaseException:  # noqa: BLE001
+                    results[-1]["ok"] = False
+                    if results[-1]["error"] is None:
+                        results[-1]["error"] = traceback.format_exc()
+                    _flush_results()
+                    break
+                torch.cuda.empty_cache()
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()


 if __name__ == "__main__":
-    kwargs = dict(arg.split("=") for arg in sys.argv[2:])
+    kwargs = dict(arg.split("=", 1) for arg in sys.argv[2:])
     main(**kwargs)
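For reference, a sketch of driving the worker's batch mode by hand (paths and config keys are illustrative; per the argv parsing above, everything from sys.argv[2:] is parsed as key=val, so a placeholder fills sys.argv[1]):

    # Sketch only: config keys and paths are illustrative, not an exact config.
    import json
    import subprocess

    configs = [  # each entry is a kwargs dict for run_dpa_with_cp
        {"model": "cp_1_0", "kernel_backend": "FusedAttention", "cp_comm_type": "p2p"},
    ]
    with open("/tmp/cp_batch.json", "w") as f:
        json.dump(configs, f)

    # The worker parses sys.argv[2:], so sys.argv[1] is a throwaway placeholder.
    subprocess.run(
        [
            "torchrun", "--nproc_per_node=2",
            "tests/pytorch/attention/run_attention_with_cp.py",
            "placeholder", "batch_config_json=/tmp/cp_batch.json",
        ],
        check=True,
    )

    # Rank 0 rewrites per-config results next to the batch file after every
    # config, so a mid-batch crash still leaves earlier results readable.
    with open("/tmp/cp_batch.json.results.json") as f:
        print(json.load(f))  # e.g. [{'ok': True, 'error': None}, ...]

Setting CP_TEST_BATCH_SIZE=1 collapses each chunk to a single config, which is the bisection workflow the commit message recommends for a failing batch.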
