From 6355f6208e2d35027d1c7e537e7c337131741753 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 6 May 2026 15:21:01 -0700 Subject: [PATCH 1/7] [PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init Each parametrized CP test currently spawns its own torchrun process and pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this adds up to 1.5-3 hours of pure setup overhead. This change introduces a session-scoped fixture that: 1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or a kwargs dict for the worker. 2. Groups runnable configs by ``num_gpus`` and chunks them into batches of CP_TEST_BATCH_SIZE (default 16). 3. Launches one torchrun per chunk; the worker initialises NCCL once and runs all configs in the chunk inside the same world. Per-config results are flushed to JSON after every config so a crash mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1 to bisect a failing batch. Also includes a small bugfix in dot_product_attention/utils.py: the deterministic-FA3 disable condition was firing for any head_dim_qk > 128 (including inference); restrict it to is_training and large head dims. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 137 +++++++- .../attention/test_attention_with_cp.py | 326 ++++++++++++++++-- 2 files changed, 415 insertions(+), 48 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8dfea644a5..2b4dbbd166 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -4,6 +4,9 @@ import os import sys +import copy +import json +import traceback import logging from contextlib import nullcontext import torch @@ -209,10 +212,10 @@ def run_dpa_with_cp( os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] + config = copy.deepcopy(model_configs_flash_attn[model]) if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] + config = copy.deepcopy(model_configs_fused_attn[model]) assert config.attn_mask_type in [ "causal", "no_mask", @@ -223,18 +226,13 @@ def run_dpa_with_cp( else: config.attn_mask_type = "padding" - # set up distributed group - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - device_count = torch.cuda.device_count() - device = rank % device_count - torch.cuda.set_device(device) + # Process group is managed by main(); one init/destroy per torchrun, not per config. + assert dist.is_initialized(), ( + "dist.init_process_group must be called before run_dpa_with_cp" + ) + world_size = dist.get_world_size() + rank = dist.get_rank() logging.info(f"[Rank {rank}] Setup: world_size {world_size}") - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) # set up communication group for CP cp_comm_ranks = range(world_size) @@ -630,7 +628,6 @@ def run_dpa_with_cp( == 0 ) else: - # Forward-only: reshape only out/out_ for comparison out = out.index_select(0, seq_idx_q).contiguous() out_ = out_ @@ -762,14 +759,118 @@ def run_dpa_with_cp( ) logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - # destroy distribution group - dist.destroy_process_group() + # Destroy per-config communication groups so they don't leak into the next + # config in batch mode. The global process group is torn down by main(). + dist.destroy_process_group(cp_comm_group) + if cp_comm_type == "a2a+p2p": + for sg in cp_comm_sub_groups: + dist.destroy_process_group(sg) + + +# Env vars set by run_dpa_with_cp; cleared between batch configs to prevent leakage. +_TRANSIENT_ENV_KEYS = ( + "NVTE_FP8_DPA_BWD", + "NVTE_DPA_FP8CS_O_in_F16", + "NVTE_FLASH_ATTN", + "NVTE_FUSED_ATTN", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO", +) + + +def _init_distributed(): + """Init NCCL process group + CUDA device once per torchrun invocation.""" + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + device_count = torch.cuda.device_count() + # Prefer LOCAL_RANK when available (set by torchrun / torch.distributed.launch); + # fall back to RANK % device_count for single-node runs. + local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count))) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + return rank, world_size + + +def _run_single_config(kwargs): + """Run one config, return ``(ok, error_message)``. + + Re-seeds RNG before each config so results are deterministic and + order-independent within a batch. + """ + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + try: + run_dpa_with_cp(**kwargs) + return True, None + except BaseException: # noqa: BLE001 - capture any failure for per-config reporting + return False, traceback.format_exc() def main(**kwargs): - run_dpa_with_cp(**kwargs) + """Entry point. + + Two modes: + + * Single-config (legacy): ``run_attention_with_cp.py key=val ...`` runs + one config, propagates exceptions for normal exit-code signalling. + * Batch: ``run_attention_with_cp.py batch_config_json=`` reads a + JSON list of kwargs dicts, runs each via ``_run_single_config``, + aggregates ``ok`` across ranks (any rank failure → False), and flushes + ``[{ok,error}, ...]`` atomically to ``.results.json`` after each + config so a worker crash mid-batch leaves earlier results intact. + Transient env vars are reset between configs; per-config NCCL groups + are torn down inside ``run_dpa_with_cp``. + """ + batch_path = kwargs.pop("batch_config_json", None) + rank, _ = _init_distributed() + try: + if batch_path is None: + run_dpa_with_cp(**kwargs) + else: + with open(batch_path, "r") as f: + configs = json.load(f) + assert isinstance(configs, list), ( + f"batch_config_json must be a JSON list, got {type(configs)}" + ) + results_path = batch_path + ".results.json" + results = [] + + def _flush_results(): + if rank != 0: + return + # Atomic write: tmp + rename so the reader never sees partial JSON. + tmp_path = results_path + ".tmp" + with open(tmp_path, "w") as f: + json.dump(results, f) + os.replace(tmp_path, results_path) + + for cfg in configs: + for env_key in _TRANSIENT_ENV_KEYS: + os.environ.pop(env_key, None) + ok, err = _run_single_config(cfg) + # Aggregate ok across ranks so a non-rank-0 failure (e.g. a + # per-partition compare assertion that fires only on rank > 0) + # is not silently swallowed when only rank 0 writes the result. + ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda") + dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN) + ok_aggregate = bool(ok_tensor.item()) + if not ok_aggregate and ok and err is None: + err = "Failed on a non-zero rank (see subprocess stderr for traceback)" + results.append({"ok": ok_aggregate, "error": err}) + _flush_results() + try: + dist.barrier() + except BaseException: # noqa: BLE001 + results[-1]["ok"] = False + if results[-1]["error"] is None: + results[-1]["error"] = traceback.format_exc() + _flush_results() + break + torch.cuda.empty_cache() + finally: + if dist.is_initialized(): + dist.destroy_process_group() if __name__ == "__main__": - kwargs = dict(arg.split("=") for arg in sys.argv[2:]) + kwargs = dict(arg.split("=", 1) for arg in sys.argv[2:]) main(**kwargs) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 23d1bfdd85..f07900be83 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -3,8 +3,9 @@ # See LICENSE for license information. import os -import subprocess +import json import sys +import tempfile import pathlib import logging import copy @@ -41,6 +42,8 @@ test_essential = True +_BATCH_SIZE = int(os.getenv("CP_TEST_BATCH_SIZE", "16")) + model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA @@ -75,6 +78,251 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args +# --------------------------------------------------------------------------- +# Batched dispatch — keeps test bodies identical to the non-batched flow +# (parametrize stack + inline ``pytest.skip(...)``) and replaces only the +# final ``run_distributed(...)`` call with ``_run_or_fetch(...)``. +# +# Flow: +# 1. Collect (dry-run, in-process). Session fixture ``_cp_batch_results`` +# walks each parametrized item that requests this fixture, calls it +# with a stub ``request`` (only ``request.node.nodeid``). The body runs +# its ``pytest.skip(...)`` checks; if none fire, ``_run_or_fetch`` +# records the kwargs in ``_COLLECTED_KWARGS`` instead of launching +# torchrun. ``@pytest.mark.skip(if)`` markers are evaluated up front +# via ``_item_static_skip`` so marker-skipped items aren't queued. +# 2. Batch + execute. Recorded kwargs are grouped by num_gpus_per_node, +# chunked into batches of CP_TEST_BATCH_SIZE (default 16), and each +# batch runs in one torchrun (``_run_one_batch``). Worker +# (run_attention_with_cp.py) inits NCCL once, loops over configs, +# flushes per-config results to ``.results.json`` atomically. +# 3. Execute mode (normal pytest run). The test body re-evaluates its +# skip checks; if none fire, ``_run_or_fetch`` looks up the recorded +# result by nodeid and asserts pass/fail. +# +# Failure handling: +# - Inline ``pytest.skip``: same code path as non-batched. +# - Worker assertion: surfaced as ``AssertionError`` from the JSON entry. +# - Per-rank failure: cross-rank ``dist.all_reduce(ok, op=MIN)`` in the +# worker so a rank > 0 assertion isn't swallowed by the rank-0-only flush. +# - Worker subprocess crash mid-batch: configs without flushed results are +# marked unattributed and ``_run_one_batch`` retries each as a singleton +# to identify the actual culprit. Disable via ``CP_TEST_BATCH_RETRY=0``. +# - Dry-run exception: caught; the same error fires in execute mode and +# pytest reports it as a normal test ERROR (no fixture-level cascade). +# +# To add a new batched test: write it like a non-batched CP test +# (parametrize + inline ``pytest.skip(...)``), accept +# ``request, _cp_batch_results`` as fixtures, and replace +# ``run_distributed(get_bash_arguments(...))`` with +# ``_run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...)``. +# Nothing else needs wiring up. +# +# Knobs: +# CP_TEST_BATCH_SIZE=N configs per torchrun; default 16; set 1 to bisect. +# CP_TEST_BATCH_RETRY=0 skip the singleton retry on unattributed crashes. +# +# Caveats: +# - ``pytest -k`` reduces what's collected and therefore what's batched. +# - The body executes once per item in collect mode (cheap; only Python +# skip logic + ``get_available_attention_backends``). +# - Mutations to module-level state during the body persist between collect +# and execute. Worker uses ``copy.deepcopy(model_configs_*[model])`` so +# ``run_dpa_with_cp`` mutations don't leak across configs in a batch. +# --------------------------------------------------------------------------- + +# Module-level state used by the session fixture's collect phase. +_COLLECT_MODE = False +_COLLECTED_KWARGS = {} # nodeid -> kwargs dict (populated in collect mode) + + +def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): + """Drop-in replacement for ``run_distributed(get_bash_arguments(...))``. + + In *collect mode* (during the session fixture's first pass), records this + test's kwargs so the fixture can batch them. In *execute mode* (the normal + test run), looks up the pre-computed result and either passes, fails, or + skips. + """ + if _COLLECT_MODE: + _COLLECTED_KWARGS[request.node.nodeid] = dict( + num_gpus=num_gpus_per_node, **worker_kwargs + ) + return + entry = batch_results.get(request.node.nodeid) + if entry is None: + pytest.skip("No batched result recorded (collection mismatch).") + if not entry.get("ok", False): + raise AssertionError(entry.get("error") or "Batched config failed (no error captured)") + + +def _run_batch_once(num_gpus, configs): + """Launch one torchrun that runs *configs* sequentially inside one NCCL world. + + Returns a list of ``{"ok": bool, "error": str|None}`` dicts, one per config. + Missing entries (subprocess crashed mid-batch) are synthesized as failures. + """ + # Stringify values: run_dpa_with_cp uses ``== "True"`` string comparisons. + # Strip ``num_gpus`` (launcher-only, not a worker kwarg). + worker_kwargs = [ + {k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs + ] + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".cp_batch.json", delete=False + ) as fh: + batch_path = fh.name + json.dump(worker_kwargs, fh) + results_path = batch_path + ".results.json" + + try: + argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path) + launch_err = None + try: + run_distributed(argv) + except AssertionError as exc: + launch_err = str(exc) + + try: + with open(results_path, "r") as f: + per_cfg = json.load(f) + except (OSError, json.JSONDecodeError): + per_cfg = [] + + results = [] + for i in range(len(configs)): + if i < len(per_cfg): + results.append(per_cfg[i]) + else: + results.append( + { + "ok": False, + "error": launch_err or "Subprocess exited before this config ran.", + "_unattributed": True, + } + ) + return results + finally: + for p in (batch_path, results_path): + try: + os.unlink(p) + except OSError: + pass + + +def _run_one_batch(num_gpus, configs): + """Run a batch, then retry any unattributed-crash entries as singletons. + + When the worker subprocess crashes (segfault / NCCL hang / OOM) before it + can flush a per-config result, every config past the crash gets the + generic "Subprocess exited before this config ran" marker and we don't + know which config is the actual culprit. Re-running each marked config in + its own torchrun pinpoints which one crashes on its own (and salvages a + real result for the ones that ran fine but were caught downstream of the + crash). + + Disable via ``CP_TEST_BATCH_RETRY=0`` — useful if the singleton retries + themselves are taking too long on a flaky cluster. + """ + results = _run_batch_once(num_gpus, configs) + if len(configs) <= 1 or not int(os.getenv("CP_TEST_BATCH_RETRY", "1")): + for r in results: + r.pop("_unattributed", None) + return results + for i, r in enumerate(results): + if r.pop("_unattributed", False): + results[i] = _run_batch_once(num_gpus, [configs[i]])[0] + results[i].pop("_unattributed", None) + return results + + +class _DummyRequest: + """Stand-in for the ``request`` fixture during the dry-run phase. + + The test body only touches ``request.node.nodeid``, so this is enough. + """ + + def __init__(self, nodeid): + self.node = type("_DummyNode", (), {"nodeid": nodeid})() + + +def _item_static_skip(item): + """Return True if pytest's static skip/skipif markers would skip *item*. + + These markers are evaluated by pytest at runtime, before the test body + runs. The dry-run calls ``item.function(...)`` directly and would bypass + them — we replicate the check here so a marker-skipped test isn't queued + for torchrun unnecessarily. + """ + for marker in item.iter_markers("skip"): + return True + for marker in item.iter_markers("skipif"): + cond = marker.args[0] if marker.args else marker.kwargs.get("condition") + if cond: + return True + return False + + +def _dry_run_item(item): + """Invoke a parametrized test body in collect mode. + + Raises ``pytest.skip.Exception`` if the body skips, otherwise returns + after ``_run_or_fetch`` has stashed the kwargs in ``_COLLECTED_KWARGS``. + """ + func = item.function + params = dict(item.callspec.params) + func(_DummyRequest(item.nodeid), {}, **params) + + +@pytest.fixture(scope="session") +def _cp_batch_results(request): + """Run all batched test bodies once in collect mode, then run torchrun batches. + + Skips are NOT tracked here — the test body raises ``pytest.skip(...)`` in + both collect and execute mode, so skipped tests never reach ``_run_or_fetch`` + and don't need an entry in the result map. + """ + global _COLLECT_MODE + + items = [ + it + for it in request.session.items + if "_cp_batch_results" in getattr(it, "fixturenames", ()) + ] + + _COLLECTED_KWARGS.clear() + _COLLECT_MODE = True + try: + for item in items: + if _item_static_skip(item): + continue # pytest will skip this at runtime; don't queue for torchrun + try: + _dry_run_item(item) + except pytest.skip.Exception: + pass # the same pytest.skip will fire again in execute mode + except BaseException: # noqa: BLE001 + # Don't let a single bad item kill the whole session fixture — + # pytest will re-raise the same error in execute mode and the + # failure will surface there as a normal test ERROR. + pass + finally: + _COLLECT_MODE = False + + by_num_gpus = {} + for nodeid, kwargs in _COLLECTED_KWARGS.items(): + num_gpus = kwargs.pop("num_gpus") + by_num_gpus.setdefault(num_gpus, []).append((nodeid, kwargs)) + + results = {} + for num_gpus, entries in by_num_gpus.items(): + for start in range(0, len(entries), _BATCH_SIZE): + chunk = entries[start : start + _BATCH_SIZE] + chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk]) + for (nodeid, _), res in zip(chunk, chunk_results): + results[nodeid] = res + return results + + dtypes = ["bf16", "fp16"] qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] @@ -91,7 +339,9 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +def test_cp_with_flash_attention( + request, _cp_batch_results, dtype, model, qkv_format, cp_comm_type +): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -140,16 +390,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if not flash_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FlashAttention", - cp_comm_type=cp_comm_type, - log_level=pytest_logging_level, - ), + _run_or_fetch( + request, + _cp_batch_results, + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ) @@ -274,7 +524,17 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( - dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O + request, + _cp_batch_results, + dtype, + model, + qkv_format, + cp_comm_type, + fp8_bwd, + fp8_mha, + fp8_dpa, + scaling_mode, + f16_O, ): config = model_configs_fused_attn[model] config.context_parallel = True @@ -386,6 +646,7 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: config_copy = copy.deepcopy(config) @@ -404,21 +665,26 @@ def test_cp_with_fused_attention( if not fused_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FusedAttention", - cp_comm_type=cp_comm_type, - fp8_bwd=fp8_bwd, - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, - scaling_mode=scaling_mode, - f16_O=f16_O, - is_training=is_training, - deterministic=_deterministic, - log_level=pytest_logging_level, - ), + if _deterministic and config.softmax_type != "vanilla": + pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") + if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + + _run_or_fetch( + request, + _cp_batch_results, + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, + fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + is_training=is_training, + deterministic=_deterministic, + log_level=pytest_logging_level, ) From 686e76b143dc07557b50d03b743c5de98c2b8651 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 14:14:54 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/run_attention_with_cp.py | 10 ++++------ .../pytorch/attention/test_attention_with_cp.py | 16 ++++------------ 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 2b4dbbd166..a78df40dad 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -227,9 +227,7 @@ def run_dpa_with_cp( config.attn_mask_type = "padding" # Process group is managed by main(); one init/destroy per torchrun, not per config. - assert dist.is_initialized(), ( - "dist.init_process_group must be called before run_dpa_with_cp" - ) + assert dist.is_initialized(), "dist.init_process_group must be called before run_dpa_with_cp" world_size = dist.get_world_size() rank = dist.get_rank() logging.info(f"[Rank {rank}] Setup: world_size {world_size}") @@ -828,9 +826,9 @@ def main(**kwargs): else: with open(batch_path, "r") as f: configs = json.load(f) - assert isinstance(configs, list), ( - f"batch_config_json must be a JSON list, got {type(configs)}" - ) + assert isinstance( + configs, list + ), f"batch_config_json must be a JSON list, got {type(configs)}" results_path = batch_path + ".results.json" results = [] diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index f07900be83..fb2f77a689 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -145,9 +145,7 @@ def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs) skips. """ if _COLLECT_MODE: - _COLLECTED_KWARGS[request.node.nodeid] = dict( - num_gpus=num_gpus_per_node, **worker_kwargs - ) + _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs) return entry = batch_results.get(request.node.nodeid) if entry is None: @@ -164,13 +162,9 @@ def _run_batch_once(num_gpus, configs): """ # Stringify values: run_dpa_with_cp uses ``== "True"`` string comparisons. # Strip ``num_gpus`` (launcher-only, not a worker kwarg). - worker_kwargs = [ - {k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs - ] + worker_kwargs = [{k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs] - with tempfile.NamedTemporaryFile( - mode="w", suffix=".cp_batch.json", delete=False - ) as fh: + with tempfile.NamedTemporaryFile(mode="w", suffix=".cp_batch.json", delete=False) as fh: batch_path = fh.name json.dump(worker_kwargs, fh) results_path = batch_path + ".results.json" @@ -285,9 +279,7 @@ def _cp_batch_results(request): global _COLLECT_MODE items = [ - it - for it in request.session.items - if "_cp_batch_results" in getattr(it, "fixturenames", ()) + it for it in request.session.items if "_cp_batch_results" in getattr(it, "fixturenames", ()) ] _COLLECTED_KWARGS.clear() From 7a42f7bf16c227ee4fe8370e6cf78760dbf811c0 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 7 May 2026 15:08:49 -0700 Subject: [PATCH 3/7] [PyTorch] Guarantee CP NCCL group cleanup with try/finally Wrap run_dpa_with_cp body in try/finally so cp_comm_group (and the a2a+p2p sub-groups) are destroyed even when the body raises -- e.g. an output-mismatch assertion, CUDA OOM, or a backend exception. Without this, a failing config inside a batched torchrun leaked communicators that accumulated across the 16-config batch window, exhausted NCCL's internal communicator table, and corrupted subsequent configs with opaque "not initialized in the world group map" errors (see B200 run in PR 2965 where one batch failure cascaded to 143 identical fall-through failures). Hoist the cp_comm_sub_groups = [] initialisation out of the a2a+p2p branch so finally always has a list to iterate, and drop the now-redundant cp_comm_type == "a2a+p2p" guard around the destroy loop (empty list is a no-op). Addresses Greptile P1 review comment on PR 2965. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 964 +++++++++--------- 1 file changed, 482 insertions(+), 482 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index a78df40dad..ec00bc01eb 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -236,6 +236,7 @@ def run_dpa_with_cp( cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + cp_comm_sub_groups = [] if cp_comm_type == "a2a+p2p": assert world_size % 2 == 0, ( "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" @@ -243,524 +244,523 @@ def run_dpa_with_cp( ) cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] for sub_ranks in cp_comm_sub_ranks: sub_group = dist.new_group(sub_ranks, backend="nccl") if rank in sub_ranks: cp_comm_sub_groups.append(sub_group) - if dtype == "fp8": + try: + if dtype == "fp8": + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "current": + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, + ).cuda() + if not is_training: + core_attn.eval() + if is_training and config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() if scaling_mode == "delayed": - fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) if scaling_mode == "current": - fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) - - # instantiate attention module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - softmax_type=config.softmax_type, - return_max_logit=config.return_max_logit, - ).cuda() - if not is_training: - core_attn.eval() - if is_training and config.softmax_type != "vanilla": - core_attn.softmax_offset.requires_grad = True - - # generate attention inputs - ( - q_input_shape, - k_input_shape, - v_input_shape, - attn_output_shape, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) - q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - dout_orig = torch.clamp( - torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 - ).cuda() - if scaling_mode == "delayed": - qkv_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) - if scaling_mode == "current": - qkv_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - device="cuda", - ) - dout_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - device="cuda", - ) - if scaling_mode == "mxfp8": - qkv_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=True, - columnwise=True, - ) - qkv_quantizer.optimize_for_gemm = True - qkv_quantizer.internal = False - dout_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - rowwise=True, - columnwise=True, - ) - dout_quantizer.optimize_for_gemm = True - dout_quantizer.internal = False - qkv_layout = "_".join([qkv_format] * 3) - q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha: - q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) - for x in [q, k, v]: - x.requires_grad = True - - if config.attn_bias_type not in ["no_bias", "alibi"]: - bias_shape_map = { - "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), - "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), - "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), - "bhss": ( - config.batch_size, - config.num_heads, - config.max_seqlen_q, - config.max_seqlen_kv, - ), - "111s": (1, 1, 1, config.max_seqlen_kv), - } - attn_bias_shape = bias_shape_map.get(config.bias_shape) - if attn_bias_shape is None: - assert False, f"cuDNN does not support {config.bias_shape=}" - bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() - # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 - # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported - bias.requires_grad = True if config.bias_shape != "111s" else False - else: - bias = None - - ############ run without CP ############ - logging.info(f"[Rank {rank}] Run without context parallelism") - if dtype == "fp8": - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) - else: - fp8_context = nullcontext() - max_logit = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, - ) - if config.return_max_logit: - out, max_logit = out - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8 = dout_quantizer(dout) - out.backward(dout_fp8) - else: - out.backward(dout) - if is_training: - dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None - d_softmax_offset = ( - core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None - ) - else: - dq, dk, dv, dbias = None, None, None, None - d_softmax_offset = None - - ############ run with CP ############ - logging.info(f"[Rank {rank}] Run with context parallelism") - - # set up inputs - q_, k_, v_, dout_, *rest = [ - x.clone().detach() - for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) - ] - bias_ = rest[0] if len(rest) else None - if qkv_format == "bshd" or qkv_format == "sbhd": - seq_dim = qkv_format.index("s") - q_, k_, v_, dout_ = [ - x.view( - *x.shape[:seq_dim], - 2 * world_size, - x.shape[seq_dim] // (2 * world_size), - *x.shape[(seq_dim + 1) :], + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, ) - for x in [q_, k_, v_, dout_] - ] - seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) - q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] - q_, k_, v_, dout_ = [ - x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] - ] - elif qkv_format == "thd": - seq_idx_q = tex.thd_get_partitioned_indices( - cu_seqlens_q_padded, q_.shape[0], world_size, rank - ) - seq_idx_kv = tex.thd_get_partitioned_indices( - cu_seqlens_kv_padded, k_.shape[0], world_size, rank - ) - q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] - k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" - q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] - if scaling_mode == "delayed": - qkv_quantizer.scale.fill_(1.0) - qkv_quantizer.amax.fill_(0.0) - dout_quantizer.scale.fill_(1.0) - dout_quantizer.amax.fill_(0.0) - if fp8_mha: - q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) - if is_training: - q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] - if bias_ is not None: - ndim = bias_.ndim - seq_q_dim = ndim - 2 - if qkv_format == "thd": - bias_seq_idx = seq_idx_q + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True + + if config.attn_bias_type not in ["no_bias", "alibi"]: + bias_shape_map = { + "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), + "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), + "bhss": ( + config.batch_size, + config.num_heads, + config.max_seqlen_q, + config.max_seqlen_kv, + ), + "111s": (1, 1, 1, config.max_seqlen_kv), + } + attn_bias_shape = bias_shape_map.get(config.bias_shape) + if attn_bias_shape is None: + assert False, f"cuDNN does not support {config.bias_shape=}" + bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 + # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported + bias.requires_grad = True if config.bias_shape != "111s" else False else: - bias_seq_idx = seq_idx - shape_before_seq = bias_.shape[:seq_q_dim] - seq_q_size = bias_.shape[seq_q_dim] - seq_kv_size = bias_.shape[-1] - if seq_q_size == 1: - # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s - bias_.requires_grad = False - # Bias is broadcast, no need to partition along sequence dimension - pass + bias = None + + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") + if dtype == "fp8": + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: - bias_ = bias_.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + fp8_context = nullcontext() + max_logit = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) - bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) - bias_.requires_grad = True - # set up environment - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, - cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - if config.softmax_type != "vanilla": - core_attn.softmax_offset.grad.zero_() - if dtype == "fp8": - core_attn.fp8_initialized = False - core_attn.fp8_meta_tensors_initialized = False - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) - else: - fp8_context = nullcontext() - - # run attention - max_logit_ = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, - ) - if config.return_max_logit: - out_, max_logit_ = out_ - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8_ = dout_quantizer(dout_) - out_.backward(dout_fp8_) - else: - out_.backward(dout_) - if is_training: - dq_, dk_, dv_, dbias_ = ( - q_.grad, - k_.grad, - v_.grad, - bias_.grad if bias_ is not None else None, - ) - d_softmax_offset_ = ( - core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None - ) - else: - dq_, dk_, dv_, dbias_ = None, None, None, None - d_softmax_offset_ = None - - # get outputs - tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] - names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] - if fp8_mha: - tensors_to_deq = [out, out_] if not fp8_bwd else tensors - for i, tensor in enumerate(tensors_to_deq): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - tensors_to_deq[i] = tensor.dequantize() - if not fp8_bwd: - tensors[0], tensors[5] = tensors_to_deq - for i, tensor in enumerate(tensors): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" - assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" - out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors - - ############ compare results between CP and no-CP ############ - if qkv_format == "bshd" or qkv_format == "sbhd": + if config.return_max_logit: + out, max_logit = out + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8 = dout_quantizer(dout) + out.backward(dout_fp8) + else: + out.backward(dout) if is_training: - dq, dk, dv, out = [ + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None + d_softmax_offset = ( + core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None + ) + else: + dq, dk, dv, dbias = None, None, None, None + d_softmax_offset = None + + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") + + # set up inputs + q_, k_, v_, dout_, *rest = [ + x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) + ] + bias_ = rest[0] if len(rest) else None + if qkv_format == "bshd" or qkv_format == "sbhd": + seq_dim = qkv_format.index("s") + q_, k_, v_, dout_ = [ x.view( *x.shape[:seq_dim], 2 * world_size, x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [dq, dk, dv, out] + for x in [q_, k_, v_, dout_] ] - dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [ - x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [dq_, dk_, dv_, out_] + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) + q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] + q_, k_, v_, dout_ = [ + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] ] - if dbias is not None and dbias_ is not None: - ndim = dbias.ndim - # Query seq is at dim -2 - seq_q_dim = ndim - 2 - shape_before_seq = dbias.shape[:seq_q_dim] - seq_q_size = dbias.shape[seq_q_dim] - seq_kv_size = dbias.shape[-1] - # Reshape to split seq_q dimension - dbias = dbias.view( + elif qkv_format == "thd": + seq_idx_q = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, q_.shape[0], world_size, rank + ) + seq_idx_kv = tex.thd_get_partitioned_indices( + cu_seqlens_kv_padded, k_.shape[0], world_size, rank + ) + q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] + k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" + q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + if is_training: + q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] + if bias_ is not None: + ndim = bias_.ndim + seq_q_dim = ndim - 2 + if qkv_format == "thd": + bias_seq_idx = seq_idx_q + else: + bias_seq_idx = seq_idx + shape_before_seq = bias_.shape[:seq_q_dim] + seq_q_size = bias_.shape[seq_q_dim] + seq_kv_size = bias_.shape[-1] + if seq_q_size == 1: + # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s + bias_.requires_grad = False + # Bias is broadcast, no need to partition along sequence dimension + pass + else: + bias_ = bias_.view( *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size ) - # Index select on the newly created dimension (now at position seq_q_dim) - dbias = dbias.index_select(seq_q_dim, seq_idx) - dbias_ = dbias_.view( - *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size - ) + bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) + bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) + bias_.requires_grad = True + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + if dtype == "fp8": + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: - # Forward-only: reshape only out/out_ for comparison - out = out.view( - *out.shape[:seq_dim], - 2 * world_size, - out.shape[seq_dim] // (2 * world_size), - *out.shape[(seq_dim + 1) :], - ) - out = out.index_select(seq_dim, seq_idx) - out_ = out_.view( - *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] + fp8_context = nullcontext() + + # run attention + max_logit_ = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - - elif qkv_format == "thd": + if config.return_max_logit: + out_, max_logit_ = out_ + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8_ = dout_quantizer(dout_) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) if is_training: - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size - cu_seqlens_q = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + dq_, dk_, dv_, dbias_ = ( + q_.grad, + k_.grad, + v_.grad, + bias_.grad if bias_ is not None else None, ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] - ] - ).item() - == 0 - ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size - cu_seqlens_kv = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + d_softmax_offset_ = ( + core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 - ) else: - out = out.index_select(0, seq_idx_q).contiguous() - out_ = out_ - - atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] - tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] - names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] - names_cp = [x + "_cp" for x in names] - names_no_cp = [x + "_no_cp" for x in names] - is_fp8 = dtype == "fp8" - for i, t in enumerate(tensors_no_cp): - if t is not None: - if "softmax_offset" not in names[i] and "max_logit" not in names[i]: - if qkv_format == "bshd": - # Compare the two sequence chunks separately - # Compare dbias - if names[i] == "dbias": - # Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias - slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - # Compare Q/K/V/out - else: - # Compare the two chunks along dimension 1 (the split sequence dimension) - compare_and_assert( - t[:, 0], - tensors_cp[i][:, 0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[:, 1], - tensors_cp[i][:, 1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - elif qkv_format == "sbhd": - # Compare the two sequence chunks separately - # Compare dbias (same as BSHD) - if names[i] == "dbias": - # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias - slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + dq_, dk_, dv_, dbias_ = None, None, None, None + d_softmax_offset_ = None + + # get outputs + tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] + if fp8_mha: + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[5] = tensors_to_deq + for i, tensor in enumerate(tensors): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" + assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" + out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors + + ############ compare results between CP and no-CP ############ + if qkv_format == "bshd" or qkv_format == "sbhd": + if is_training: + dq, dk, dv, out = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], + ) + for x in [dq, dk, dv, out] + ] + dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] + dq_, dk_, dv_, out_ = [ + x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) + for x in [dq_, dk_, dv_, out_] + ] + if dbias is not None and dbias_ is not None: + ndim = dbias.ndim + # Query seq is at dim -2 + seq_q_dim = ndim - 2 + shape_before_seq = dbias.shape[:seq_q_dim] + seq_q_size = dbias.shape[seq_q_dim] + seq_kv_size = dbias.shape[-1] + # Reshape to split seq_q dimension + dbias = dbias.view( + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + ) + # Index select on the newly created dimension (now at position seq_q_dim) + dbias = dbias.index_select(seq_q_dim, seq_idx) + dbias_ = dbias_.view( + *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size + ) + else: + # Forward-only: reshape only out/out_ for comparison + out = out.view( + *out.shape[:seq_dim], + 2 * world_size, + out.shape[seq_dim] // (2 * world_size), + *out.shape[(seq_dim + 1) :], + ) + out = out.index_select(seq_dim, seq_idx) + out_ = out_.view( + *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] + ) + + elif qkv_format == "thd": + if is_training: + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + ) + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] + ] + ).item() + == 0 ) - # Compare Q/K/V/out - else: - # Compare the two chunks along dimension 0 (the split sequence dimension) - compare_and_assert( - t[0], - tensors_cp[i][0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + ] + ).item() + == 0 ) + else: + out = out.index_select(0, seq_idx_q).contiguous() + out_ = out_ + + atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i] and "max_logit" not in names[i]: + if qkv_format == "bshd": + # Compare the two sequence chunks separately + # Compare dbias + if names[i] == "dbias": + # Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 1 (the split sequence dimension) + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + # Compare the two sequence chunks separately + # Compare dbias (same as BSHD) + if names[i] == "dbias": + # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 0 (the split sequence dimension) + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": compare_and_assert( - t[1], - tensors_cp[i][1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 ) - elif qkv_format == "thd": + else: compare_and_assert( t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 ) - else: - compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 - ) - logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - - # Destroy per-config communication groups so they don't leak into the next - # config in batch mode. The global process group is torn down by main(). - dist.destroy_process_group(cp_comm_group) - if cp_comm_type == "a2a+p2p": + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + finally: + # Destroy per-config groups so they survive exceptions in batch mode. + # The global process group is torn down by main(). + dist.destroy_process_group(cp_comm_group) for sg in cp_comm_sub_groups: dist.destroy_process_group(sg) From 3c08a7d0122049427280069b66ea802a4288e639 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 22:12:07 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index ec00bc01eb..3b3cfe87a5 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -256,7 +256,9 @@ def run_dpa_with_cp( if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + ) # instantiate attention module core_attn = DotProductAttention( @@ -360,7 +362,9 @@ def run_dpa_with_cp( ############ run without CP ############ logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + ) else: fp8_context = nullcontext() max_logit = None @@ -418,7 +422,8 @@ def run_dpa_with_cp( seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] q_, k_, v_, dout_ = [ - x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + for x in [q_, k_, v_, dout_] ] elif qkv_format == "thd": seq_idx_q = tex.thd_get_partitioned_indices( @@ -475,7 +480,9 @@ def run_dpa_with_cp( if dtype == "fp8": core_attn.fp8_initialized = False core_attn.fp8_meta_tensors_initialized = False - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + ) else: fp8_context = nullcontext() @@ -561,7 +568,10 @@ def run_dpa_with_cp( seq_kv_size = dbias.shape[-1] # Reshape to split seq_q dimension dbias = dbias.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + *shape_before_seq, + 2 * world_size, + seq_q_size // (2 * world_size), + seq_kv_size, ) # Index select on the newly created dimension (now at position seq_q_dim) dbias = dbias.index_select(seq_q_dim, seq_idx) @@ -599,9 +609,9 @@ def run_dpa_with_cp( num_pads_q[b] == 0 or torch.count_nonzero( x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] + ( + cu_seqlens_q_padded[b + 1] - num_pads_q[b] + ) : cu_seqlens_q_padded[b + 1] ] ).item() == 0 @@ -750,7 +760,14 @@ def run_dpa_with_cp( ) elif qkv_format == "thd": compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + t, + tensors_cp[i], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, ) else: compare_and_assert( From 1cb19172eb6189beb7fdf42113b386b6c5374517 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 7 May 2026 15:21:37 -0700 Subject: [PATCH 5/7] [PyTorch] Broaden launch-error catch in CP batch dispatch Widen the except in _run_batch_once from AssertionError to Exception so OS-level failures from subprocess.run (FileNotFoundError when the worker script is missing, PermissionError, OSError when fds are exhausted, etc.) are attributed to the batch they came from instead of escaping the session-scoped _cp_batch_results fixture and ERROR-ing every CP test in the run. Addresses Greptile P1 review comment on PR 2965. Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index fb2f77a689..d0d38689af 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -174,8 +174,13 @@ def _run_batch_once(num_gpus, configs): launch_err = None try: run_distributed(argv) - except AssertionError as exc: - launch_err = str(exc) + except Exception as exc: + # Catch broadly: subprocess.run can raise OSError/FileNotFoundError + # in addition to the AssertionError that run_distributed wraps a + # non-zero exit in. Letting any of these propagate would tear down + # the session fixture and ERROR every batched test instead of + # marking just this batch's configs as failed. + launch_err = str(exc) or repr(exc) try: with open(results_path, "r") as f: From 9adbbdcaaf9dca0b62e95fe47cb135afe237ddc8 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 8 May 2026 09:37:03 -0700 Subject: [PATCH 6/7] [PyTorch] Fix FP8 cascade failures and skip divergence in CP batch tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FP8GlobalStateManager retains quantizer registrations that reference destroyed NCCL process groups, causing cascade failures when multiple FP8 configs run in a single torchrun batch. Reset the singleton between configs to prevent this. get_available_attention_backends is stateful — calling it during the dry-run collect phase can produce different results than during the execute phase, causing "skip divergence" where the batch collects configs that should have been skipped. Cache backend availability per test node ID so the decision is consistent across phases. Also: pass MASTER_PORT through to torchrun so parallel pytest invocations on different GPU sets don't collide, and add [CP-BATCH] progress logging to the batch infrastructure. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 2 + .../attention/test_attention_with_cp.py | 85 ++++++++++++++----- 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 3b3cfe87a5..deab0c480c 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -24,6 +24,7 @@ Float8CurrentScalingQuantizer, MXFP8Quantizer, ) +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, @@ -859,6 +860,7 @@ def _flush_results(): os.replace(tmp_path, results_path) for cfg in configs: + FP8GlobalStateManager.reset() for env_key in _TRANSIENT_ENV_KEYS: os.environ.pop(env_key, None) ok, err = _run_single_config(cfg) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index d0d38689af..efed582262 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -70,6 +70,8 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus_per_node), ] + if "MASTER_PORT" in os.environ: + args.append("--master-port=" + os.environ["MASTER_PORT"]) te_path = os.getenv("TE_PATH", "/opt/transformerengine") script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") args.append(script_path) @@ -134,6 +136,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): # Module-level state used by the session fixture's collect phase. _COLLECT_MODE = False _COLLECTED_KWARGS = {} # nodeid -> kwargs dict (populated in collect mode) +_BACKEND_CACHE = {} # nodeid -> (fused_attn_supported, fused_attn_backends) or (flash_attn_supported,) def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): @@ -287,8 +290,11 @@ def _cp_batch_results(request): it for it in request.session.items if "_cp_batch_results" in getattr(it, "fixturenames", ()) ] + import time as _time + _COLLECTED_KWARGS.clear() _COLLECT_MODE = True + _t0 = _time.monotonic() try: for item in items: if _item_static_skip(item): @@ -304,6 +310,11 @@ def _cp_batch_results(request): pass finally: _COLLECT_MODE = False + print( + f"\n[CP-BATCH] Collect done: {len(_COLLECTED_KWARGS)} configs from" + f" {len(items)} items in {_time.monotonic() - _t0:.1f}s", + flush=True, + ) by_num_gpus = {} for nodeid, kwargs in _COLLECTED_KWARGS.items(): @@ -312,11 +323,29 @@ def _cp_batch_results(request): results = {} for num_gpus, entries in by_num_gpus.items(): - for start in range(0, len(entries), _BATCH_SIZE): + n_batches = (len(entries) + _BATCH_SIZE - 1) // _BATCH_SIZE + for batch_idx, start in enumerate(range(0, len(entries), _BATCH_SIZE)): chunk = entries[start : start + _BATCH_SIZE] + print( + f"[CP-BATCH] Running batch {batch_idx + 1}/{n_batches}" + f" ({len(chunk)} cfgs, {num_gpus} GPUs)...", + flush=True, + ) + _bt = _time.monotonic() chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk]) + ok = sum(1 for r in chunk_results if r.get("ok")) + print( + f"[CP-BATCH] => {ok}/{len(chunk)} passed" + f" in {_time.monotonic() - _bt:.1f}s", + flush=True, + ) for (nodeid, _), res in zip(chunk, chunk_results): results[nodeid] = res + print( + f"[CP-BATCH] All batches done: {len(results)} results" + f" in {_time.monotonic() - _t0:.1f}s total", + flush=True, + ) return results @@ -378,12 +407,17 @@ def test_cp_with_flash_attention( if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} - available_backends, *_ = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype], - qkv_layout="_".join([qkv_format] * 3), - ) - flash_attn_supported, *_ = available_backends + nodeid = request.node.nodeid + if nodeid in _BACKEND_CACHE: + flash_attn_supported = _BACKEND_CACHE[nodeid] + else: + available_backends, *_ = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + ) + flash_attn_supported, *_ = available_backends + _BACKEND_CACHE[nodeid] = flash_attn_supported if not flash_attn_supported: pytest.skip("No attention backend available.") @@ -634,23 +668,12 @@ def test_cp_with_fused_attention( # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, - qkv_layout="_".join([qkv_format] * 3), - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - deterministic=_deterministic, - ) - - _, fused_attn_supported, _ = available_backends - if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: - config_copy = copy.deepcopy(config) - config_copy.context_parallel = False - config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + nodeid = request.node.nodeid + if nodeid in _BACKEND_CACHE: + fused_attn_supported = _BACKEND_CACHE[nodeid] + else: available_backends, _, fused_attn_backends = get_available_attention_backends( - config_copy, + config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), fp8=fp8, @@ -658,7 +681,23 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends + if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: + config_copy = copy.deepcopy(config) + config_copy.context_parallel = False + config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + available_backends, _, fused_attn_backends = get_available_attention_backends( + config_copy, + qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, + qkv_layout="_".join([qkv_format] * 3), + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported, _ = available_backends + _BACKEND_CACHE[nodeid] = fused_attn_supported if not fused_attn_supported: pytest.skip("No attention backend available.") From 97fcd4cd96f9a83a708f42c554e976a3f2406a93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 19:38:50 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index efed582262..176b36b7d8 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -136,7 +136,9 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): # Module-level state used by the session fixture's collect phase. _COLLECT_MODE = False _COLLECTED_KWARGS = {} # nodeid -> kwargs dict (populated in collect mode) -_BACKEND_CACHE = {} # nodeid -> (fused_attn_supported, fused_attn_backends) or (flash_attn_supported,) +_BACKEND_CACHE = ( + {} +) # nodeid -> (fused_attn_supported, fused_attn_backends) or (flash_attn_supported,) def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): @@ -335,8 +337,7 @@ def _cp_batch_results(request): chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk]) ok = sum(1 for r in chunk_results if r.get("ok")) print( - f"[CP-BATCH] => {ok}/{len(chunk)} passed" - f" in {_time.monotonic() - _bt:.1f}s", + f"[CP-BATCH] => {ok}/{len(chunk)} passed in {_time.monotonic() - _bt:.1f}s", flush=True, ) for (nodeid, _), res in zip(chunk, chunk_results):