
[PyTorch] Batch CP attention tests in single torchrun to amortize NCC…#2965

Open
sudhakarsingh27 wants to merge 2 commits into NVIDIA:main from sudhakarsingh27:sudhakars/cp_test_batching_pr

Conversation


@sudhakarsingh27 sudhakarsingh27 commented May 6, 2026

Design

Problem

Each parametrized CP attention test spawns its own torchrun and pays 5–15 s of NCCL init/destroy. ~650–800 collected tests → 1.5–3 h of pure setup overhead. We need that overhead amortised without changing how tests are written or how skips are reported.

Approach

Run multiple configs inside one torchrun that shares a single NCCL world. A session-scoped fixture (_cp_batch_results) does two passes:

  1. Collect (dry-run, in-process). Walk pytest's collected items. For each item that requests _cp_batch_results, call its test function directly with a stubbed request. The body executes its inline pytest.skip(...) checks normally; if any fires, the item is dropped from the batch. Otherwise the body's final call to _run_or_fetch(...) records its kwargs in a module-level dict instead of launching a subprocess.
  2. Batch + execute. Group recorded kwargs by num_gpus_per_node, chunk into batches of CP_TEST_BATCH_SIZE (default 16), launch one torchrun per chunk. Worker (run_attention_with_cp.py) inits NCCL once, loops over configs, atomically flushes per-config results to <batch>.results.json. When pytest later runs each test for real, the body re-evaluates skips and _run_or_fetch looks up the recorded result.

The dry-run mechanism in detail

Batching needs to know upfront — before any test body runs in pytest's normal execute phase — which configs are runnable, what kwargs each needs, and how they should be grouped (by num_gpus_per_node). The dry-run is how we extract that information from the test bodies without duplicating their skip logic in a separate registry.

What runs

@pytest.fixture(scope="session")
def _cp_batch_results(request):
    global _COLLECT_MODE  # module-level flag read by _run_or_fetch
    items = [it for it in request.session.items
             if "_cp_batch_results" in getattr(it, "fixturenames", ())]
    _COLLECT_MODE = True
    for item in items:
        if _item_static_skip(item):
            continue  # static skip/skipif marker: never queued
        try:
            _dry_run_item(item)
        except pytest.skip.Exception:
            pass  # inline pytest.skip fired: item dropped from the batch
        except BaseException:
            pass  # surfaces in execute mode as a normal pytest error
    _COLLECT_MODE = False
    # group _COLLECTED_KWARGS by num_gpus, chunk, run torchrun batches

_dry_run_item is the actual call into the test body:

def _dry_run_item(item):
    func = item.function  # the undecorated test function
    params = dict(item.callspec.params)  # this item's parametrize values
    # Stub request + empty batch_results dict; _run_or_fetch records kwargs.
    func(_DummyRequest(item.nodeid), {}, **params)

This bypasses pytest's normal runner (runtest_protocol) entirely — no fixture setup hooks, no plugin reporters, no captured-stdout machinery. Just calls the underlying function with the same parametrize values pytest would have passed.

How _COLLECT_MODE works

_run_or_fetch checks a module-level flag:

def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs):
    if _COLLECT_MODE:
        _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs)
        return  # never reaches the lookup; never asserts
    entry = batch_results.get(request.node.nodeid)
    ...

So a single function serves both phases: in collect mode it's a recorder, in execute mode it's a result-fetcher. The test body doesn't know which mode it's in — it just calls one helper at the end.

A global flag is the simplest signal that survives across the function-call boundary. A context manager / contextvar would also work; we don't need one because _dry_run_item is called serially in one thread.

What gets stubbed

The test bodies declare two fixtures: request and _cp_batch_results. During dry-run we provide both as stubs:

| Param | Stub | Why this works |
| --- | --- | --- |
| `request` | `_DummyRequest(nodeid)` — has only `request.node.nodeid` | `_run_or_fetch` only reads `nodeid` (to key `_COLLECTED_KWARGS`); the test body itself never touches `request`. |
| `_cp_batch_results` | `{}` (empty dict) | `_run_or_fetch` returns early in collect mode and never inspects `batch_results`. |

If a future test ever uses request.config.getoption(...) etc., either the stub gets extended or the test moves to a more elaborate dry-run path. The dry-run loop already catches BaseException to keep one over-eager test from poisoning the whole fixture.

How skip checks survive

Inline pytest.skip("reason") raises _pytest.outcomes.Skipped (exposed as pytest.skip.Exception). The dry-run loop catches this per-item, drops the item from the batch (no kwargs recorded), and moves on. In execute mode the same line raises again and pytest's runner reports the test as SKIPPED with the same reason. Neither phase touches _run_or_fetch for skipped items.

@pytest.mark.skip and @pytest.mark.skipif(<bool_condition>) markers are not raised when calling item.function(...) directly — pytest evaluates them in its runner, which we're bypassing. _item_static_skip(item) walks item.iter_markers("skip"|"skipif") and reads marker.args[0] (the condition) before the dry-run, skipping items the markers would otherwise drop. This avoids queuing a config that pytest will then refuse to run anyway.

A skipif with a string-expression condition or a runtime-evaluated condition would slip past _item_static_skip (we don't eval() strings). Such items would still be skipped correctly in execute mode — we'd just have wasted one slot in a torchrun batch.

What the two phases produce

After dry-run:

  • _COLLECTED_KWARGS: { nodeid → {num_gpus, dtype, model, qkv_format, ...} } for every non-skipped item.
  • Skipped items leave no trace anywhere — they re-skip in execute mode via the same pytest.skip(...) call.

The fixture then groups by num_gpus, chunks into batches, runs _run_one_batch(num_gpus, [kwargs, ...]) for each chunk, and assembles the results dict ({ nodeid → {ok, error} }).

In execute mode pytest runs each test as it normally would. The body re-evaluates skip checks (cheap, deterministic), and if no skip fires it calls _run_or_fetch which now does the dict lookup and either returns (PASS) or raises AssertionError (FAIL).

Why two evaluations of the skip logic is safe

Both phases call the same Python code on the same parametrize args. The only inputs that could differ between the two are:

  • module-level state mutated by other dry-run iterations (e.g. model_configs_*[model].context_parallel = True) — but every iteration sets it the same way, so the value is idempotent;
  • external state (driver / cuDNN probes via get_available_attention_backends) — if this changes between collect and execute, a config that was queued might skip at execute time. Harmless: the recorded result goes unused.

If a skip fires only at execute time (was not present in dry-run), the lookup in _run_or_fetch is short-circuited because the body already raised pytest.skip(...) before reaching it. If somehow the test body reached _run_or_fetch without a result (collection mismatch), the helper falls back to pytest.skip("No batched result recorded (collection mismatch).") rather than asserting on missing data.

Cost

For every parametrized item that requests _cp_batch_results, the body runs once during the session fixture (collect mode) and once again during pytest's normal execute. Skip checks are pure Python; the only non-trivial work is get_available_attention_backends, which the original non-batched flow also called once per test in the pytest process. Net cost: one extra get_available_attention_backends call per item. Negligible compared to the NCCL setup time the batching saves.

Why this shape

| Decision | Rationale |
| --- | --- |
| Dry-run instead of `_prepare_*` helpers | Test bodies stay textually identical to non-batched tests. The diff is `+request, _cp_batch_results` in the signature and `run_distributed(get_bash_arguments(...))` → `_run_or_fetch(request, _cp_batch_results, ...)` at the end. Inline `pytest.skip(...)` calls and the `@pytest.mark.parametrize` stack are unchanged. One source of truth for skip logic. |
| Stub `request` (only `.node.nodeid`) | That's the only attribute `_run_or_fetch` reads. A 4-line shim avoids dragging in pytest's full fixture machinery during collect. |
| `@pytest.mark.skip(if)` markers honoured up front | Static markers don't fire during a direct `item.function(...)` call, so we explicitly check them in `_item_static_skip(item)` before the dry-run; otherwise we'd queue marker-skipped configs for torchrun and waste cycles. |
| Atomic per-config flush (tmp + `os.replace`) | A reader (the pytest driver) never sees partial JSON. A worker crash mid-config preserves earlier results. |
| Cross-rank `dist.all_reduce(ok, op=MIN)` after each config | Only rank 0 writes the JSON. Without aggregation, a per-partition assertion that fires on rank > 0 (CP comparisons run independently per rank) would be silently swallowed. The reduce makes any rank's failure flip the recorded `ok` to `False`. |
| Auto-retry crashed batch entries as singletons | A worker crash (segfault / NCCL hang / OOM) before flush leaves unattributed entries. `_run_one_batch` re-runs each in its own torchrun to identify the actual culprit and salvage real results for innocent neighbours. Disable via `CP_TEST_BATCH_RETRY=0`. |
| Per-config NCCL communication group teardown | Sub-groups (`cp_comm_group`, a2a+p2p sub-groups) are destroyed at the end of `run_dpa_with_cp` so they don't leak across configs in the same world. |
| `copy.deepcopy(model_configs_*[model])` in worker | The THD path mutates `attn_mask_type`; without the copy, the next config in the batch reads a mutated entry. |
| Reset `_TRANSIENT_ENV_KEYS` between configs | `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, `NVTE_FP8_DPA_BWD`, `NVTE_DPA_FP8CS_O_in_F16`, `NVTE_ALLOW_NONDETERMINISTIC_ALGO` are set per-config — clearing them prevents leakage that would alter backend selection in the next config. |
| `arg.split("=", 1)` | Lets values containing `=` (paths) survive. |

Equivalence to the original flow

For any given parametrised test, the externally-visible behaviour matches the non-batched flow:

| What the user sees | Original | This PR |
| --- | --- | --- |
| Test body source | inline `pytest.skip(...)` + `run_distributed(get_bash_arguments(...))` | inline `pytest.skip(...)` + `_run_or_fetch(...)` |
| Skip reason in pytest output | from inline `pytest.skip(reason)` | same reason, same call site |
| Pass / fail report | from worker's exit code / assertion | from worker's per-config JSON entry |
| Failure traceback | full traceback from worker stderr (via `run_distributed`) | full traceback captured per-config in the worker, surfaced as `AssertionError` |
| `pytest -k`, `--collect-only`, `--co`, `--tb` | works | works |
| `@pytest.mark.skip(if)` markers | honoured | honoured (explicit pre-check before dry-run) |

Where the flow can diverge from the original

These are real-world edge cases where the batched path could behave differently. Each is documented; most are mitigated.

  1. Dry-run side effects. The collect phase invokes the test body once in-process with a stub request — including its skip checks and any get_available_attention_backends(...) calls. Effects:
    • Modules import again (cheap).
    • get_available_attention_backends may probe CUDA / cuDNN. Same probe runs in execute mode, so this is duplicated work, not a behaviour change.
    • model_configs_*[model] is mutated in place (e.g. config.context_parallel = True). Original tests do this too, so no divergence — but a future test that mutates something else between collect and execute could see the residual state. Mitigation: worker run_dpa_with_cp does copy.deepcopy(model_configs_*[model]) before mutating.
  2. request API surface during dry-run. Only request.node.nodeid is provided. A future test that uses request.config.getoption(...) or request.getfixturevalue(...) would AttributeError during dry-run. Mitigation: the fixture catches BaseException from dry-run and lets the same error fire in execute mode where pytest's real request is available.
  3. Cross-config state leakage in the worker. All configs in a batch share one Python process and one NCCL world. Anything the original test relied on starting from a clean process state may behave differently. Mitigations applied: per-config NCCL sub-group destruction, _TRANSIENT_ENV_KEYS reset, torch.cuda.empty_cache(), dist.barrier(), RNG re-seed (1234) at start of each config, copy.deepcopy of shared config dicts. Anything not on that list is a potential leak.
  4. Worker crash mid-batch. Original: only the affected test fails; remaining tests run independently. This PR: configs after the crash that didn't get flushed get marker entries, then are retried as singletons (so each gets a real attribution). The retry adds NCCL init/destroy overhead per affected config. Disable via CP_TEST_BATCH_RETRY=0.
  5. Per-rank failures. Original: if any rank's run_dpa_with_cp raises, torchrun returns non-zero and pytest fails the test. This PR: only rank 0 writes the JSON; cross-rank dist.all_reduce(ok, op=MIN) after each config matches the original's "any rank failure ⇒ test fails" semantics.
  6. @pytest.mark.skipif(condition_evaluated_at_runtime). A skipif whose condition becomes True only at execute time (e.g. depends on a fixture) would not be detected by the static _item_static_skip check. The condition would still fire in execute mode and pytest would skip — but we'd have wasted one torchrun slot for it. None of the current markers behave this way.
  7. get_available_attention_backends non-determinism. If this returns different values between dry-run and execute (driver state changes, etc.), a config could be queued by collect but skip in execute. Harmless: _run_or_fetch is never reached, the unused batch result is garbage-collected.
  8. Pytest internals. The dry-run uses item.function, item.callspec.params, and pytest.skip.Exception. These are stable in pytest 7+/8+. If they shift in a future major version, _dry_run_item is a 3-line shim to update.

In short: the externally-visible behaviour is equivalent for any test that already worked with the non-batched flow. The places it could legitimately diverge are listed above and either mitigated in code or contained behind an env var.

Knobs

| Env var | Effect |
| --- | --- |
| `CP_TEST_BATCH_SIZE=N` | Configs per torchrun. Default 16. Set to 1 to bisect. |
| `CP_TEST_BATCH_RETRY=0` | Disable singleton retry for unattributed crashes. |

Adding a new batched test

  1. Write the test the way you would write a non-batched CP test: @pytest.mark.parametrize stack + inline pytest.skip(...) checks.
  2. Add request, _cp_batch_results to the function signature.
  3. Replace the trailing run_distributed(get_bash_arguments(...)) with _run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...) (kwargs become the worker's run_dpa_with_cp(**kwargs) arguments).

That's the entire wiring. No registry, no signature mirroring, no preparer.

Failure semantics

| Outcome | What pytest sees |
| --- | --- |
| Inline `pytest.skip(...)` fires | Standard SKIP with the same reason (skip is re-evaluated in execute mode and short-circuits before `_run_or_fetch`). |
| `@pytest.mark.skip(if)` marker fires | Standard SKIP via pytest's normal path (not queued for torchrun). |
| Config ran, assertion failed | FAIL with worker's traceback. |
| Config ran on rank > 0 with assertion, rank 0 OK | FAIL (cross-rank `all_reduce`). |
| Worker subprocess crashed before flush | Each affected config retried as a singleton; the real result wins, residual crashes surface as FAIL with attribution. |
| Dry-run itself raised | Caught and ignored in the fixture; the same exception fires in execute mode and pytest reports a normal test ERROR. |

Validation

Local H100 (8 GPU), test_essential = True:

  • 35 passed / 10234 skipped / 0 unrelated failures, 216 s.
  • 3 pre-existing cp_3_1 (MLA) + non-P2P fused-attention failures reproduce on origin/main with non-batched test files; not introduced by this PR.

Stress (no regressions):

  • single nodeid: 1 passed / 18 s.
  • -k <no-match>: 10272 deselected, fixture not invoked.
  • --collect-only: 10272 collected, no torchrun.
  • -k <small subset>: 7 passed / 383 skipped.
  • CP_TEST_BATCH_SIZE=1: 6 passed / 254 skipped.

Files

  • tests/pytorch/attention/test_attention_with_cp.py — collect/dispatch/fetch infra, dry-run helpers, test bodies updated minimally.
  • tests/pytorch/attention/run_attention_with_cp.py — _init_distributed, main() batch mode, atomic per-config flush, cross-rank aggregation, per-config group teardown, copy.deepcopy of model configs, transient env reset, split("=", 1).

Type of change

  • Code refactoring (test infrastructure; no production-code change)

Checklist

  • Contributing guidelines followed
  • Functionality complete
  • Code commented where non-obvious
  • Documentation (n/a — internal test infra)
  • No new warnings
  • Existing test suite serves as input + validation
  • Existing tests pass locally


greptile-apps Bot commented May 6, 2026

Greptile Summary

This PR amortises NCCL init/destroy overhead for CP attention tests by introducing a session-scoped pytest fixture that groups parametrized tests into batches of 16, launches one torchrun per batch, and reads back per-config JSON results — cutting 1.5–3 hours of pure setup from the test suite. It also adds a small bugfix restricting the deterministic-FA3 disable condition to training + large head dims only.

  • run_attention_with_cp.py: Removes per-call dist.init/destroy from run_dpa_with_cp, adds a new main() batch loop that shares one NCCL world across configs, uses dist.all_reduce(ok, MIN) to surface non-rank-0 failures, and atomically flushes per-config results to JSON.
  • test_attention_with_cp.py: Adds _cp_batch_results (session fixture), dry-run collect phase, _run_or_fetch drop-in, _run_one_batch with singleton retry for unattributed crashes, and two new skip guards for unsupported deterministic+softmax combinations in test_cp_with_fused_attention.

Confidence Score: 3/5

Safe to merge for the batching infrastructure itself; however run_dpa_with_cp creates NCCL communicators without a try/finally guard, so any mid-function exception in a batch config leaks those communicators for the rest of the torchrun lifetime.

The cp_comm_group and a2a+p2p sub-groups are created early in run_dpa_with_cp and destroyed only at the very end of the function with no try/finally. In single-config mode this was a cosmetic resource issue; in the new batch mode a single flaky config can orphan communicators that accumulate silently across the 16-config window, potentially exhausting NCCL's communicator table and corrupting subsequent configs with opaque errors.

tests/pytorch/attention/run_attention_with_cp.py — specifically the cp_comm_group lifecycle inside run_dpa_with_cp.

Important Files Changed

| Filename | Overview |
| --- | --- |
| tests/pytorch/attention/run_attention_with_cp.py | Adds batch-mode entry point with NCCL init/destroy refactored into `main()`; `cp_comm_group` cleanup sits at the end of `run_dpa_with_cp` without try/finally, causing a communicator leak on any mid-function exception. |
| tests/pytorch/attention/test_attention_with_cp.py | Introduces the session-scoped `_cp_batch_results` fixture with collect/execute phases; batches tests by `num_gpus` with singleton retry for unattributed crashes; new skip guards for deterministic + non-vanilla softmax combinations. |

Sequence Diagram

sequenceDiagram
    participant pytest
    participant fixture as _cp_batch_results (session)
    participant dry as Dry-run phase
    participant batch as _run_one_batch
    participant torchrun as torchrun subprocess
    participant worker as run_attention_with_cp.py (worker)

    pytest->>fixture: session setup
    fixture->>dry: "iterate items, _COLLECT_MODE=True"
    dry->>dry: call test body with _DummyRequest
    dry-->>fixture: pytest.skip discarded / ok stashed
    fixture->>fixture: group by num_gpus, chunk by BATCH_SIZE
    loop each chunk
        fixture->>batch: _run_one_batch(num_gpus, configs)
        batch->>torchrun: launch (one NCCL world)
        torchrun->>worker: _init_distributed() once
        loop each config in batch
            worker->>worker: _run_single_config(cfg)
            worker->>worker: dist.all_reduce(ok, MIN) across ranks
            worker->>worker: flush results JSON atomically
            worker->>worker: dist.barrier()
        end
        torchrun-->>batch: exit
        batch->>batch: read results.json
        batch-->>fixture: results list
    end
    fixture-->>pytest: results dict by nodeid

    pytest->>pytest: execute each test
    pytest->>pytest: _run_or_fetch lookup result


Comment thread tests/pytorch/attention/run_attention_with_cp.py
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from fa189b0 to 0e9fc1f Compare May 6, 2026 23:01
Comment thread tests/pytorch/attention/test_attention_with_cp.py Outdated
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch 4 times, most recently from 7802ec5 to c80df5d Compare May 7, 2026 13:57
Comment on lines +147 to +153
try:
    argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path)
    launch_err = None
    try:
        run_distributed(argv)
    except AssertionError as exc:
        launch_err = str(exc)

P1 Only AssertionError caught from run_distributed

subprocess.run inside run_distributed can raise FileNotFoundError, PermissionError, or OSError for OS-level failures (missing executable, exhausted file descriptors, etc.). These propagate uncaught through _run_batch_once → _run_one_batch → _cp_batch_results. Because the fixture is session-scoped, one such exception causes every test that depends on _cp_batch_results to surface as a fixture ERROR rather than an individual test failure. In the original code, the same OS error would fail only the one test that triggered it. Widening the except to except (AssertionError, Exception) before reading the results file would preserve the per-batch isolation benefit.

…L init

Each parametrized CP test currently spawns its own torchrun process and
pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this
adds up to 1.5-3 hours of pure setup overhead.

This change introduces a session-scoped fixture that:
  1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or
     a kwargs dict for the worker.
  2. Groups runnable configs by ``num_gpus`` and chunks them into batches
     of CP_TEST_BATCH_SIZE (default 16).
  3. Launches one torchrun per chunk; the worker initialises NCCL once
     and runs all configs in the chunk inside the same world.

Per-config results are flushed to JSON after every config so a crash
mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1
to bisect a failing batch.

Also includes a small bugfix in dot_product_attention/utils.py: the
deterministic-FA3 disable condition was firing for any head_dim_qk > 128
(including inference); restrict it to is_training and large head dims.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from 1db76b7 to 6355f62 Compare May 7, 2026 14:14
Comment on lines +762 to +765
    dist.destroy_process_group(cp_comm_group)
    if cp_comm_type == "a2a+p2p":
        for sg in cp_comm_sub_groups:
            dist.destroy_process_group(sg)

P1 NCCL communicator leak on exception mid-function

cp_comm_group is created at line 238 (and up to 4 sub-groups for a2a+p2p at lines 248-250) but the destroy calls are at the very bottom of the function with no try/finally. Any exception that fires in between — a CUDA OOM, a comparison mismatch, a BaseException from cuDNN — causes _run_single_config to catch it and return (False, traceback), while the communicators are never cleaned up.

In batch mode the problem compounds: with 16 configs per torchrun and any flaky configs, leaked communicators accumulate across the whole batch. NCCL's internal communicator table has a fixed limit (typically 128), so a few hundred batched configs with occasional failures can exhaust it and corrupt subsequent configs with opaque "NCCL error: invalid usage" rather than surfacing the original failure. Wrapping the body after group creation in a try/finally guarantees cleanup.
