[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init #2965
sudhakarsingh27 wants to merge 2 commits into NVIDIA:main
Conversation
Greptile Summary

This PR amortises NCCL init/destroy overhead for CP attention tests by introducing a session-scoped pytest fixture that groups parametrized tests into batches of 16 and launches one torchrun (one NCCL world) per batch.
Confidence Score: 3/5

Safe to merge for the batching infrastructure itself; however, `run_dpa_with_cp` creates NCCL communicators without a try/finally guard, so any mid-function exception in a batch config leaks those communicators for the rest of the torchrun lifetime. The `cp_comm_group` and the `a2a+p2p` sub-groups are created early in `run_dpa_with_cp` and destroyed only at the very end of the function. In single-config mode this was a cosmetic resource issue; in the new batch mode a single flaky config can orphan communicators that accumulate silently across the 16-config window, potentially exhausting NCCL's communicator table and corrupting subsequent configs with opaque errors. See `tests/pytorch/attention/run_attention_with_cp.py` — specifically the `cp_comm_group` lifecycle inside `run_dpa_with_cp`.
Sequence Diagram

```mermaid
sequenceDiagram
    participant pytest
    participant fixture as _cp_batch_results (session)
    participant dry as Dry-run phase
    participant batch as _run_one_batch
    participant torchrun as torchrun subprocess
    participant worker as run_attention_with_cp.py (worker)
    pytest->>fixture: session setup
    fixture->>dry: iterate items, _COLLECT_MODE=True
    dry->>dry: call test body with _DummyRequest
    dry-->>fixture: pytest.skip discarded / ok stashed
    fixture->>fixture: group by num_gpus, chunk by BATCH_SIZE
    loop each chunk
        fixture->>batch: _run_one_batch(num_gpus, configs)
        batch->>torchrun: launch (one NCCL world)
        torchrun->>worker: _init_distributed() once
        loop each config in batch
            worker->>worker: _run_single_config(cfg)
            worker->>worker: dist.all_reduce(ok, MIN) across ranks
            worker->>worker: flush results JSON atomically
            worker->>worker: dist.barrier()
        end
        torchrun-->>batch: exit
        batch->>batch: read results.json
        batch-->>fixture: results list
    end
    fixture-->>pytest: results dict by nodeid
    pytest->>pytest: execute each test
    pytest->>pytest: _run_or_fetch lookup result
```
```python
try:
    argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path)
    launch_err = None
    try:
        run_distributed(argv)
    except AssertionError as exc:
        launch_err = str(exc)
```
Only `AssertionError` is caught from `run_distributed`

`subprocess.run` inside `run_distributed` can raise `FileNotFoundError`, `PermissionError`, or `OSError` for OS-level failures (missing executable, exhausted file descriptors, etc.). These propagate uncaught through `_run_batch_once` → `_run_one_batch` → `_cp_batch_results`. Because the fixture is session-scoped, one such exception causes every test that depends on `_cp_batch_results` to surface as a fixture ERROR rather than an individual test failure. In the original code, the same OS error would fail only the one test that triggered it. Widening the except clause (e.g. to `except Exception`) before reading the results file would preserve the per-batch isolation benefit.
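A minimal sketch of that suggestion, reusing the names from the diff above (`run_distributed`, `argv`, `launch_err`); this is the reviewer's idea made concrete, not the PR's code:

```python
# Widened handler: OS-level launch failures now fail only this batch's
# tests instead of ERROR-ing every test behind the session fixture.
launch_err = None
try:
    run_distributed(argv)
except Exception as exc:  # was: except AssertionError
    # FileNotFoundError / PermissionError / OSError from subprocess.run
    # land here; the results file is still read below, so earlier
    # per-config results in the batch survive.
    launch_err = f"{type(exc).__name__}: {exc}"
```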
[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init
Each parametrized CP test currently spawns its own torchrun process and
pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this
adds up to 1.5-3 hours of pure setup overhead.
This change introduces a session-scoped fixture that:
1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or
a kwargs dict for the worker.
2. Groups runnable configs by ``num_gpus`` and chunks them into batches
of CP_TEST_BATCH_SIZE (default 16).
3. Launches one torchrun per chunk; the worker initialises NCCL once
and runs all configs in the chunk inside the same world.
Per-config results are flushed to JSON after every config so a crash
mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1
to bisect a failing batch.
Also includes a small bugfix in dot_product_attention/utils.py: the
deterministic-FA3 disable condition was firing for any head_dim_qk > 128
(including inference); restrict it to is_training and large head dims.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
[pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci
```python
dist.destroy_process_group(cp_comm_group)
if cp_comm_type == "a2a+p2p":
    for sg in cp_comm_sub_groups:
        dist.destroy_process_group(sg)
```
NCCL communicator leak on exception mid-function
cp_comm_group is created at line 238 (and up to 4 sub-groups for a2a+p2p at lines 248-250) but the destroy calls are at the very bottom of the function with no try/finally. Any exception that fires in between — a CUDA OOM, a comparison mismatch, a BaseException from cuDNN — causes _run_single_config to catch it and return (False, traceback), while the communicators are never cleaned up.
In batch mode the problem compounds: with 16 configs per torchrun and any flaky configs, leaked communicators accumulate across the whole batch. NCCL's internal communicator table has a fixed limit (typically 128), so a few hundred batched configs with occasional failures can exhaust it and corrupt subsequent configs with opaque "NCCL error: invalid usage" rather than surfacing the original failure. Wrapping the body after group creation in a try/finally guarantees cleanup.
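A sketch of the suggested guard; the names follow `run_attention_with_cp.py`, while the group-creation arguments and the function body are elided:

```python
import torch.distributed as dist

cp_comm_group = dist.new_group(cp_ranks, backend="nccl")
cp_comm_sub_groups = []
if cp_comm_type == "a2a+p2p":
    ...  # create the a2a+p2p sub-groups as before
try:
    ...  # rest of run_dpa_with_cp: build modules, run fwd/bwd, compare
finally:
    # Runs on success *and* on any exception, so a flaky config can no
    # longer leak communicators into the rest of the 16-config batch.
    dist.destroy_process_group(cp_comm_group)
    for sg in cp_comm_sub_groups:  # empty list unless a2a+p2p
        dist.destroy_process_group(sg)
```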
Design
Problem
Each parametrized CP attention test spawns its own `torchrun` and pays 5–15 s of NCCL init/destroy. With ~650–800 collected tests that is 1.5–3 h of pure setup overhead. We need that overhead amortised without changing how tests are written or how skips are reported.

Approach
Run multiple configs inside one `torchrun` that shares a single NCCL world. A session-scoped fixture (`_cp_batch_results`) does two passes:

1. **Dry-run (collect).** For every collected item that requests `_cp_batch_results`, call its test function directly with a stubbed `request`. The body executes its inline `pytest.skip(...)` checks normally; if any fires, the item is dropped from the batch. Otherwise the body's final call to `_run_or_fetch(...)` records its kwargs in a module-level dict instead of launching a subprocess.
2. **Dispatch (batch).** Group the recorded kwargs by `num_gpus_per_node`, chunk into batches of `CP_TEST_BATCH_SIZE` (default 16), and launch one `torchrun` per chunk. The worker (`run_attention_with_cp.py`) inits NCCL once, loops over configs, and atomically flushes per-config results to `<batch>.results.json`. When pytest later runs each test for real, the body re-evaluates skips and `_run_or_fetch` looks up the recorded result.

A condensed sketch of the fixture shape follows.
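The helper names (`_dry_run_item`, `_run_one_batch`, `_COLLECTED_KWARGS`, `CP_TEST_BATCH_SIZE`) follow the PR's naming; the body here is illustrative, not the PR's actual code:

```python
import os
import pytest

CP_TEST_BATCH_SIZE = int(os.environ.get("CP_TEST_BATCH_SIZE", "16"))

@pytest.fixture(scope="session")
def _cp_batch_results(request):
    # Pass 1 (dry-run): call each relevant test body directly so it either
    # skips or records its worker kwargs in _COLLECTED_KWARGS.
    for item in request.session.items:
        if "_cp_batch_results" in getattr(item, "fixturenames", ()):
            _dry_run_item(item)
    # Pass 2 (dispatch): group by world size, chunk, one torchrun per chunk.
    by_gpus = {}
    for nodeid, kwargs in _COLLECTED_KWARGS.items():
        by_gpus.setdefault(kwargs["num_gpus_per_node"], []).append((nodeid, kwargs))
    results = {}
    for num_gpus, configs in by_gpus.items():
        for i in range(0, len(configs), CP_TEST_BATCH_SIZE):
            results.update(_run_one_batch(num_gpus, configs[i:i + CP_TEST_BATCH_SIZE]))
    return results  # { nodeid -> {"ok": bool, "error": str | None} }
```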
The dry-run mechanism in detail

Batching needs to know upfront — before any test body runs in pytest's normal execute phase — which configs are runnable, what kwargs each needs, and how they should be grouped (by `num_gpus_per_node`). The dry-run is how we extract that information from the test bodies without duplicating their skip logic in a separate registry.

What runs
`_dry_run_item` is the actual call into the test body, sketched below. It bypasses pytest's normal runner (`runtest_protocol`) entirely — no fixture setup hooks, no plugin reporters, no captured-stdout machinery. It just calls the underlying function with the same parametrize values pytest would have passed.
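The original snippet did not survive extraction; this is a plausible reconstruction of its shape, hedged rather than verbatim:

```python
import pytest

def _dry_run_item(item):
    kwargs = dict(item.callspec.params)              # the same parametrize values
    kwargs["request"] = _DummyRequest(item.nodeid)   # stub, defined further down
    kwargs["_cp_batch_results"] = {}                 # stub; never inspected
    try:
        item.function(**kwargs)                      # direct call, no runner
    except pytest.skip.Exception:
        pass          # inline pytest.skip(...): drop the item from the batch
    except BaseException:
        pass          # let the same error fire again in execute mode
```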
How `_COLLECT_MODE` works

`_run_or_fetch` checks a module-level flag, sketched below, so a single function serves both phases: in collect mode it's a recorder, in execute mode it's a result-fetcher. The test body doesn't know which mode it's in — it just calls one helper at the end.
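A plausible reconstruction of the helper; the flag, the dict, and the fallback skip message are the PR's, the exact body is assumed:

```python
import pytest

_COLLECT_MODE = False      # flipped to True by the fixture during dry-run
_COLLECTED_KWARGS = {}     # { nodeid -> worker kwargs }

def _run_or_fetch(request, batch_results, **kwargs):
    nodeid = request.node.nodeid
    if _COLLECT_MODE:                      # collect phase: recorder
        _COLLECTED_KWARGS[nodeid] = kwargs
        return
    result = batch_results.get(nodeid)     # execute phase: result-fetcher
    if result is None:
        pytest.skip("No batched result recorded (collection mismatch).")
    assert result["ok"], result["error"]
```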
A global flag is the simplest signal that survives across the function-call boundary. A context manager / contextvar would also work; we don't need one because `_dry_run_item` is called serially in one thread.

What gets stubbed
The test bodies declare two fixtures: `request` and `_cp_batch_results`. During dry-run we provide both as stubs:

- `request` → `_DummyRequest(nodeid)` — has only `request.node.nodeid`. `_run_or_fetch` only reads `nodeid` (to key `_COLLECTED_KWARGS`); the test body itself never touches `request`.
- `_cp_batch_results` → `{}` (empty dict). `_run_or_fetch` returns early in collect mode and never inspects `batch_results`.
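The stub is small enough to show in full; the class name is the PR's, the body is assumed from the description above:

```python
from types import SimpleNamespace

class _DummyRequest:
    """Only request.node.nodeid must exist during dry-run."""
    def __init__(self, nodeid):
        self.node = SimpleNamespace(nodeid=nodeid)
```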
If a future test ever uses `request.config.getoption(...)` etc., either the stub gets extended or the test moves to a more elaborate dry-run path. The dry-run loop already catches `BaseException` to keep one over-eager test from poisoning the whole fixture.

How skip checks survive
Inline `pytest.skip("reason")` raises `_pytest.outcomes.Skipped` (exposed as `pytest.skip.Exception`). The dry-run loop catches this per item, drops the item from the batch (no kwargs recorded), and moves on. In execute mode the same line raises again and pytest's runner reports the test as SKIPPED with the same reason. Neither phase touches `_run_or_fetch` for skipped items.

`@pytest.mark.skip` and `@pytest.mark.skipif(<bool_condition>)` markers are not raised when calling `item.function(...)` directly — pytest evaluates them in its runner, which we're bypassing. `_item_static_skip(item)` walks `item.iter_markers("skip"|"skipif")` and reads `marker.args[0]` (the condition) before the dry-run, skipping items the markers would otherwise drop. This avoids queuing a config that pytest will then refuse to run anyway.
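A sketch of the static check, with an assumed return convention; it deliberately handles only plain boolean `skipif` conditions, matching the caveat that follows:

```python
def _item_static_skip(item):
    """Return a skip reason if a skip/skipif marker would drop this item."""
    for marker in item.iter_markers("skip"):
        return marker.kwargs.get("reason", "skip marker")
    for marker in item.iter_markers("skipif"):
        condition = marker.args[0] if marker.args else False
        if condition is True:           # plain bools only; strings not eval()'d
            return marker.kwargs.get("reason", "skipif marker")
    return None
```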
A `skipif` with a string-expression condition or a runtime-evaluated condition would slip past `_item_static_skip` (we don't `eval()` strings). Such items would still be skipped correctly in execute mode — we'd just have wasted one slot in a `torchrun` batch.

What the two phases produce
After dry-run:

- `_COLLECTED_KWARGS`: `{ nodeid → {num_gpus, dtype, model, qkv_format, ...} }` for every non-skipped item.
- Skipped items: nothing recorded; the skip came from the body's own inline `pytest.skip(...)` call.
num_gpus, chunks into batches, runs_run_one_batch(num_gpus, [kwargs, ...])for each chunk, and assembles theresultsdict ({ nodeid → {ok, error} }).In execute mode pytest runs each test as it normally would. The body re-evaluates skip checks (cheap, deterministic), and if no skip fires it calls
_run_or_fetchwhich now does the dict lookup and either returns (PASS) or raisesAssertionError(FAIL).Why two evaluations of the skip logic is safe
Both phases call the same Python code on the same parametrize args. The only inputs that could differ between the two are:

- shared config mutation (`model_configs_*[model].context_parallel = True`) — but every iteration sets it the same way, so the value is idempotent;
- backend availability (`get_available_attention_backends`) — if this changes between collect and execute, a config that was queued might skip at execute time. Harmless: the recorded result goes unused.
_run_or_fetchis short-circuited because the body already raisedpytest.skip(...)before reaching it. If somehow the test body reached_run_or_fetchwithout a result (collection mismatch), the helper falls back topytest.skip("No batched result recorded (collection mismatch).")rather than asserting on missing data.Cost
For every parametrized item that requests `_cp_batch_results`, the body runs once during the session fixture (collect mode) and once again during pytest's normal execute phase. Skip checks are pure Python; the only non-trivial work is `get_available_attention_backends`, which the original non-batched flow also called once per test in the pytest process. Net cost: one extra `get_available_attention_backends` call per item. Negligible compared to the NCCL setup time the batching saves.

Why this shape
- No `_prepare_*` helpers; test bodies change minimally: `+request, _cp_batch_results` in the signature and `run_distributed(get_bash_arguments(...))` → `_run_or_fetch(request, _cp_batch_results, ...)` at the end. Inline `pytest.skip(...)` calls and the `@pytest.mark.parametrize` stack are unchanged. One source of truth for skip logic.
- Stubbed `request` exposes only `.node.nodeid` — the single attribute `_run_or_fetch` reads. A 4-line shim avoids dragging in pytest's full fixture machinery during collect.
- `@pytest.mark.skip(if)` markers honoured up front: markers are not evaluated by a direct `item.function(...)` call, so we explicitly check them in `_item_static_skip(item)` before the dry-run; otherwise we'd queue marker-skipped configs for `torchrun` and waste cycles.
- Atomic per-config result flush (tmp + `os.replace`) and `dist.all_reduce(ok, op=MIN)` after each config: any rank's failure flips `ok` to False (sketched after this list).
- Batch retry: on a batch-level crash, `_run_one_batch` re-runs each config in its own `torchrun` to identify the actual culprit and salvage real results for innocent neighbours. Disable via `CP_TEST_BATCH_RETRY=0`.
- Per-config group teardown: process groups (`cp_comm_group`, `a2a+p2p` sub-groups) are destroyed at the end of `run_dpa_with_cp` so they don't leak across configs in the same world.
- `copy.deepcopy(model_configs_*[model])` in the worker: it mutates e.g. `attn_mask_type`; without the copy the next config in the batch reads a mutated entry.
- Reset `_TRANSIENT_ENV_KEYS` between configs: `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, `NVTE_FP8_DPA_BWD`, `NVTE_DPA_FP8CS_O_in_F16`, `NVTE_ALLOW_NONDETERMINISTIC_ALGO` are set per-config — clearing them prevents leakage that would alter backend selection in the next config.
- `arg.split("=", 1)` in the worker's arg parsing: split on the first `=` only so values containing `=` (paths) survive.
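A sketch of the flush-and-aggregate step named above; the function and variable names are assumptions, while the tmp + `os.replace` and `all_reduce(MIN)` mechanics are the PR's:

```python
import json
import os
import torch
import torch.distributed as dist

def _flush_config_result(results, results_path, ok, rank):
    # Any rank's failure flips ok to 0 everywhere ("any rank fails => fail").
    flag = torch.tensor([1 if ok else 0], device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    if rank == 0:
        tmp = results_path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(results, f)
        os.replace(tmp, results_path)  # atomic: a later crash keeps this state
    dist.barrier()                      # nobody starts the next config early
    return bool(flag.item())
```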
Equivalence to the original flow

For any given parametrised test, the externally-visible behaviour matches the non-batched flow:

- Test body shape: `pytest.skip(...)` + `run_distributed(get_bash_arguments(...))` becomes `pytest.skip(...)` + `_run_or_fetch(...)`.
- Skips still surface as `pytest.skip(reason)`.
- Failures still surface as `AssertionError` (previously raised by `run_distributed`).
- `pytest -k`, `--collect-only`, `--co --tb` behave as before.
- `@pytest.mark.skip(if)` markers are honoured.

Where the flow can diverge from the original
These are real-world edge cases where the batched path could behave differently. Each is documented; most are mitigated.
- Test bodies execute twice. During dry-run the body runs with a stubbed `request` — including its skip checks and any `get_available_attention_backends(...)` calls. Effects: `get_available_attention_backends` may probe CUDA / cuDNN. The same probe runs in execute mode, so this is duplicated work, not a behaviour change.
- `model_configs_*[model]` is mutated in place (e.g. `config.context_parallel = True`). Original tests do this too, so no divergence — but a future test that mutates something else between collect and execute could see the residual state. Mitigation: the worker's `run_dpa_with_cp` does `copy.deepcopy(model_configs_*[model])` before mutating.
- Reduced `request` API surface during dry-run. Only `request.node.nodeid` is provided. A future test that uses `request.config.getoption(...)` or `request.getfixturevalue(...)` would `AttributeError` during dry-run. Mitigation: the fixture catches `BaseException` from the dry-run and lets the same error fire in execute mode where pytest's real `request` is available.
- State shared across configs in one world. Explicit per-config resets: `_TRANSIENT_ENV_KEYS` reset, `torch.cuda.empty_cache()`, `dist.barrier()`, RNG re-seed (1234) at the start of each config, `copy.deepcopy` of shared config dicts. Anything not on that list is a potential leak.
- Batch-level crashes take out the whole chunk; the retry pass re-runs each config alone to salvage results. Disable with `CP_TEST_BATCH_RETRY=0`.
- Multi-rank failure semantics. Original: any rank's `run_dpa_with_cp` raises, `torchrun` returns non-zero and pytest fails the test. This PR: only rank 0 writes the JSON; the cross-rank `dist.all_reduce(ok, op=MIN)` after each config matches the original's "any rank failure ⇒ test fails" semantics.
- `@pytest.mark.skipif(condition_evaluated_at_runtime)`. A `skipif` whose condition becomes True only at execute time (e.g. depends on a fixture) would not be detected by the static `_item_static_skip` check. The condition would still fire in execute mode and pytest would skip — but we'd have wasted one `torchrun` slot for it. None of the current markers behave this way.
- `get_available_attention_backends` non-determinism. If this returns different values between dry-run and execute (driver state changes, etc.), a config could be queued by collect but skip in execute. Harmless: `_run_or_fetch` is never reached, the unused batch result is garbage-collected.
- Reliance on pytest internals: `item.function`, `item.callspec.params`, and `pytest.skip.Exception`. These are stable in pytest 7+/8+. If they shift in a future major version, `_dry_run_item` is a 3-line shim to update.

In short: the externally-visible behaviour is equivalent for any test that already worked with the non-batched flow. The places it could legitimately diverge are listed above and either mitigated in code or contained behind an env var.
Knobs
- `CP_TEST_BATCH_SIZE=N` — configs per `torchrun`. Default 16. Set to 1 to bisect a failing batch.
- `CP_TEST_BATCH_RETRY=0` — disable the per-config retry pass after a batch-level crash.

Adding a new batched test
1. Write the `@pytest.mark.parametrize` stack + inline `pytest.skip(...)` checks as usual.
2. Add `request, _cp_batch_results` to the function signature.
3. Replace `run_distributed(get_bash_arguments(...))` with `_run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...)` (the kwargs become the worker's `run_dpa_with_cp(**kwargs)` arguments).

That's the entire wiring. No registry, no signature mirroring, no preparer. An illustrative test body follows.
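Illustrative only: the test name, parametrize values, and skip condition below are made up; the fixtures and `_run_or_fetch` are the PR's:

```python
import pytest

@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "a2a"])
def test_cp_example(request, _cp_batch_results, dtype, cp_comm_type):
    if dtype == "fp16" and cp_comm_type == "a2a":
        pytest.skip("made-up unsupported combination")  # runs in both phases
    _run_or_fetch(
        request,
        _cp_batch_results,
        num_gpus_per_node=2,
        dtype=dtype,
        cp_comm_type=cp_comm_type,  # kwargs feed run_dpa_with_cp(**kwargs)
    )
```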
Failure semantics
- SKIPPED — an inline `pytest.skip(...)` fires (the body never reaches `_run_or_fetch`).
- SKIPPED — a `@pytest.mark.skip(if)` marker fires (the item is never queued for `torchrun`).
- FAILED — the config failed on any rank (aggregated via `all_reduce`).

Validation
Local H100 (8 GPU), `test_essential = True`: `cp_3_1` (MLA) + non-P2P fused-attention failures reproduce on origin/main with the non-batched test files; not introduced by this PR.

Stress (no regressions):

- `-k <no-match>`: 10272 deselected, fixture not invoked.
- `--collect-only`: 10272 collected, no `torchrun`.
- `-k <small subset>`: 7 passed / 383 skipped.
- `CP_TEST_BATCH_SIZE=1`: 6 passed / 254 skipped.

Files
- `tests/pytorch/attention/test_attention_with_cp.py` — collect/dispatch/fetch infra, dry-run helpers, test bodies updated minimally.
- `tests/pytorch/attention/run_attention_with_cp.py` — `_init_distributed`, `main()` batch mode, atomic per-config flush, cross-rank aggregation, per-config group teardown, `copy.deepcopy` of model configs, transient env reset, `split("=", 1)`.