[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443

Open
denera wants to merge 38 commits into NVIDIA:main from denera:common/tp-overlap-cublasmp

Conversation

@denera (Collaborator) commented Dec 2, 2025

Description

This PR adds support for the NVTE cuBlasMp bindings in the Comm+GEMM overlap API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera self-assigned this Dec 2, 2025
@denera denera force-pushed the common/tp-overlap-cublasmp branch 2 times, most recently from 908bbc2 to 69cf235 Compare December 2, 2025 20:12
Comment thread transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.
Member comment:
nit: cuBLASMp

@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
ExtComm comm);

void ub_barrier(ExtComm comm);

int64_t get_nccl_comm_ptr(std::string comm_name) {
NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
Member comment:
This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 4596411 to b4ad546 Compare December 16, 2025 19:04
@denera denera marked this pull request as ready for review December 16, 2025 22:58
greptile-apps Bot (Contributor) commented Dec 16, 2025

Greptile Summary

This PR integrates the cuBLASMp backend into the Comm+GEMM overlap API across both the PyTorch and JAX frameworks, adding new cuBLASMp-path constructors for CommOverlapCore/CommOverlapBase/CommOverlapP2PBase, NCCL communicator bootstrap logic in CommOverlapHelper, and dispatch guards throughout the overlap execution paths. Several bugs from prior review rounds have been addressed (concurrent backend launches, the --use-cublasmp flag being discarded, ncclCommInitAll → ncclCommInitRank, torch.transpose missing args). A number of P1 findings from previous rounds remain unresolved in the current head — most critically the wrong get_tp_domain_id() used as tp_rank in the JAX cuBLASMp executor, the unconditional ProcessGroupNCCL.hpp include in common.h, and the silent bulk-overlap fallback in base.py — and should be addressed before merge.

Confidence Score: 3/5

Not safe to merge: several P1 issues from prior review rounds remain unfixed in the current head.

Multiple P1 bugs from earlier review rounds are still present: wrong tp_rank argument in the JAX cuBLASMp executor (get_tp_domain_id instead of get_local_device_id_within_tp_domain), unconditional ProcessGroupNCCL.hpp include in common.h breaking non-NCCL builds, and silent cuBLASMp fallback to Userbuffers for bulk overlap. New finding is P2 only. Multiple outstanding P1s pull the score below the single-P1 ceiling of 4.

transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (wrong tp_rank), transformer_engine/pytorch/csrc/common.h (unconditional ProcessGroupNCCL.hpp include), transformer_engine/pytorch/module/base.py (silent bulk overlap fallback), tests/cpp_distributed/test_comm_gemm.cu (too-loose BF16 tolerance).

Important Files Changed

Filename Overview
transformer_engine/jax/csrc/extensions/cgemm_helper.cpp cuBLASMp path added in get_executor(); uses get_tp_domain_id() as tp_rank (P1, unfixed from prior review); plan_id now includes use_cublasmp; stale file cleanup and pgid removal are improvements.
transformer_engine/pytorch/csrc/common.h Adds unconditional ProcessGroupNCCL.hpp include; P1 flagged in prior review – should be guarded with #ifdef NVTE_WITH_CUBLASMP.
transformer_engine/pytorch/module/base.py initialize_ub() gains with_cublasmp param; CommOverlapP2P and CommOverlap dispatch added; CommOverlapHelper now always passed intra group for single-domain case; silent bulk fallback still present (flagged in prior review).
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp Adds cuBLASMp constructors for CommOverlapCore/Base/P2PBase; uses ncclCommInitRank correctly; cublasmp_ag_gemm/gemm_rs/gemm_ar dispatch paths added. No intra NCCL comm when intra_domain_group absent, partially mitigated by callers always passing an intra group.
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp CommOverlapHelper extended with NCCL comm initialization (world + intra) behind NVTE_WITH_CUBLASMP. ncclCommInitRank used correctly.
tests/cpp_distributed/test_comm_gemm.cu RS/AR reference rewritten with proper NCCL-based scatter; AG reference uses AllGather helpers; BF16 AG tolerance relaxed from 1e-3 to 6e-1 (600×, P2 concern).
tests/pytorch/distributed/run_gemm_with_overlap.py Adds cuBLASMp code paths; reference computation refactored to local-rank comparison; torch.transpose/tuple-unpack bugs from prior review fixed; atol relaxed to 0.002.
examples/jax/collective_gemm/run_test_cgemm.sh Many previously flagged issues fixed: BACKENDS array comma removed, -c arg split, per-backend wait added, --use-cublasmp now appended correctly.

Sequence Diagram

sequenceDiagram
    participant PY as Python caller
    participant CH as CommOverlapHelper
    participant NCO as ncclComm (world/intra)
    participant COR as CommOverlapCore (cuBLASMp path)
    participant CTX as NVTECommGemmCtx

    PY->>CH: CommOverlapHelper(world_pg, intra_pg)
    CH->>NCO: ncclGetUniqueId → broadcast → ncclCommInitRank (world)
    CH->>NCO: ncclGetUniqueId → broadcast → ncclCommInitRank (intra)
    CH-->>PY: helper ready

    PY->>COR: CommOverlap(helper, tp_rank, tp_size, ...)
    COR->>CH: get_nccl_comm("intra")
    CH-->>COR: ncclComm_t
    COR->>CTX: nvte_comm_gemm_ctx_create(nccl_comm, tp_size, tp_rank)
    CTX-->>COR: _cublasmp_ctx

    PY->>COR: split_overlap_rs / atomic_gemm_overlap_ag / ...
    COR->>CTX: nvte_gemm_reduce_scatter / nvte_all_gather_gemm(stream)
    CTX-->>COR: async result
    COR-->>PY: output tensor

Reviews (20): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

greptile-apps Bot (Contributor) left a comment
Additional Comments (8)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 335 (link)

    logic: Variable shadowing bug: k is assigned k * _tp_size where k appears on both sides. Should be k = k_local * _tp_size.

  2. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 135 (link)

    logic: Invalid reinterpret_cast: cannot cast an int* (pointer) to int (value). Should be reinterpret_cast<void**>(&handler._device_barrier).

  3. transformer_engine/pytorch/csrc/extensions.h, line 517 (link)

    syntax: Stray character a that will cause compilation failure.

  4. transformer_engine/pytorch/csrc/extensions.h, line 537-540 (link)

    logic: Constructor parameter mismatch: CommOverlapBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.

  5. transformer_engine/pytorch/csrc/extensions.h, line 563-566 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.

  6. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 196-199 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_domain_id, ...). Should use tp_rank instead of tp_domain_id.

    Should this use tp_rank (local_device_id_within_tp_domain) or node_id (tp_domain_id)?

  7. tests/pytorch/distributed/run_gemm_with_overlap.py, line 416-418 (link)

    style: Unconditional initialization: local_kernel2_t_shape is only used when ub_obj2 is not None, but it's always initialized here. This creates an unused variable in most cases.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  8. tests/pytorch/distributed/run_gemm_with_overlap.py, line 490 (link)

    style: Incorrect initialization: ref2_g should only be assigned when ub_obj2 is not None, but it's unconditionally assigned to (0,) which doesn't match the expected tensor type.
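The shadowing fix in comment 1 can be sketched in Python (names are hypothetical stand-ins for the C++ locals):

```python
# Sketch of the fix for comment 1: derive the global K dimension from the local
# shard under a distinct name, instead of the buggy `k = k * _tp_size`, where
# `k` on the right-hand side is not the local value it was meant to be.
def global_k(k_local: int, tp_size: int) -> int:
    return k_local * tp_size

assert global_k(128, 4) == 512
```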

18 files reviewed, 8 comments


@denera denera force-pushed the common/tp-overlap-cublasmp branch from 147036f to c5471f8 Compare December 17, 2025 02:15
denera and others added 6 commits December 17, 2025 02:16
…rk extensions

Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from c5471f8 to d79bf21 Compare December 17, 2025 02:16
greptile-apps Bot (Contributor) left a comment

Additional Comments (3)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 87 (link)

    logic: return value from nvte_comm_gemm_ctx_create() is not assigned to _cublasmp_ctx, causing null pointer when destructor calls nvte_comm_gemm_ctx_destroy(_cublasmp_ctx) on line 173

  2. transformer_engine/jax/cpp_extensions/gemm.py, line 819 (link)

    style: parameter order changed - collective_op moved after transpose_batch_sequence, sequence_dim, is_outer. Verify this matches the calling convention and doesn't break compatibility. Was this parameter reordering intentional, and have all call sites been verified?

  3. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 89-91 (link)

    logic: _tp_size field is not initialized in the cuBlasMp constructor, but it's used in cublasmp_ag_gemm() and cublasmp_gemm_rs() methods (lines 321, 332). Add initialization
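Comments 1 and 3 are the same class of constructor bug: state that later methods and the destructor rely on is never set. A minimal Python sketch of the corrected shape (all names are hypothetical stand-ins for the C++ members):

```python
class CommOverlapSketch:
    """Hypothetical stand-in for the cuBLASMp constructor path."""

    def __init__(self, nccl_comm, tp_rank: int, tp_size: int):
        self._tp_size = tp_size                  # comment 3: initialize the field
        self._ctx = self._ctx_create(nccl_comm)  # comment 1: keep the returned handle

    def _ctx_create(self, comm):
        # stand-in for nvte_comm_gemm_ctx_create(); returns an opaque handle
        return object()

    def destroy(self):
        # stand-in for the destructor calling the ctx-destroy function on _ctx
        assert self._ctx is not None, "ctx handle was never assigned"
        self._ctx = None

ub = CommOverlapSketch(nccl_comm=None, tp_rank=0, tp_size=4)
assert ub._tp_size == 4
ub.destroy()
```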

18 files reviewed, 3 comments


Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 364b416 to ee517d3 Compare December 17, 2025 02:50
greptile-apps Bot (Contributor) left a comment

Additional Comments (6)

  1. transformer_engine/pytorch/module/base.py, line 415-417 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  2. transformer_engine/pytorch/module/base.py, line 387-389 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  3. tests/pytorch/distributed/run_gemm_with_overlap.py, line 340-344 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  4. tests/pytorch/distributed/run_gemm_with_overlap.py, line 355-359 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  5. tests/pytorch/distributed/run_gemm_with_overlap.py, line 383 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  6. tests/pytorch/distributed/run_gemm_with_overlap.py, line 394 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.
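All six findings are the same positional-argument swap. One defensive pattern on the Python side (a sketch, not the actual bindings) is to route construction through a keyword-only wrapper, so tp_rank and tp_size cannot be transposed silently:

```python
def make_comm_overlap(helper, *, tp_rank: int, tp_size: int, **kwargs):
    """Hypothetical wrapper; keyword-only arguments make the order unswappable,
    and the range check catches a swap even if callers build kwargs dynamically."""
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank={tp_rank} out of range for tp_size={tp_size}")
    # forward to the real constructor here, e.g.
    # tex.CommOverlap(helper, tp_rank, tp_size, **kwargs)
    return {"tp_rank": tp_rank, "tp_size": tp_size, **kwargs}

cfg = make_comm_overlap(None, tp_rank=2, tp_size=8, num_comm_sm=16)
assert cfg["tp_rank"] == 2 and cfg["tp_size"] == 8
```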

19 files reviewed, 6 comments


Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 5cb8204 to 51b64fb Compare December 17, 2025 03:36
Comment thread transformer_engine/common/CMakeLists.txt Outdated
Comment thread transformer_engine/jax/cpp_extensions/gemm.py Outdated
@denera denera force-pushed the common/tp-overlap-cublasmp branch from a25e667 to 6c6cc4d Compare March 16, 2026 17:14
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
BACKENDS=("userbuffers", "cublasmp")
Contributor comment:

Bash array comma syntax creates malformed "userbuffers," element

In bash, array elements are separated by whitespace — commas are treated as literal characters, not separators. This declaration:

BACKENDS=("userbuffers", "cublasmp")

creates a two-element array where the first element is "userbuffers," (with a trailing comma) and the second is "cublasmp". Downstream, LOG_FILE and grep patterns expand to ${TEST_NAME}_gpu_0_userbuffers,.log (comma included), which may confuse any tooling that parses the filenames. The fix is to use whitespace separation:

Suggested change
BACKENDS=("userbuffers", "cublasmp")
BACKENDS=("userbuffers" "cublasmp")

Comment on lines +108 to +117
)
if [ "$BACKEND" == "cublasmp" ]; then
pytest_args+=("--use-cublasmp")
fi

pytest_args=(
"-s"
"-c $TE_PATH/tests/jax/pytest.ini"
"-vs"
)
Contributor comment:

--use-cublasmp flag silently discarded — cuBLASMp backend never actually tested

pytest_args+=("--use-cublasmp") on line 110 appends to pytest_args before it is initialised. Then line 113 unconditionally resets pytest_args to a fresh array, discarding the previously appended flag. The result is that pytest is always invoked without --use-cublasmp, so the "cublasmp" loop iteration runs identical Userbuffers tests to the first iteration — the cuBLASMp code path is never exercised.

The fix is to initialise pytest_args before the conditional, then append the backend-specific flag:

      pytest_args=(
        "-s"
        "-c" "$TE_PATH/tests/jax/pytest.ini"
        "-vs"
      )
      if [ "$BACKEND" == "cublasmp" ]; then
        pytest_args+=("--use-cublasmp")
      fi

denera added 2 commits March 16, 2026 17:57
Signed-off-by: Alp Dener <adener@nvidia.com>
Comment on lines +439 to +448
ub_obj = tex.CommOverlap(
helper,
local_rank,
tp_size,
num_comm_sm=num_sm,
atomic_gemm=atomic_gemm,
)
else:
ub_obj = tex.CommOverlap(
shape, # Communication buffer shape
Contributor comment:

Silent cuBLASMp fallback for bulk overlap goes unannounced

When with_cublasmp=True is passed but the configured method is "bulk", the condition with_cublasmp and method != "bulk" is False and the code silently constructs a regular Userbuffers CommOverlap object instead. The user has explicitly opted into cuBLASMp but will unknowingly receive the Userbuffers backend for every bulk-overlap buffer — with no warning anywhere in the output.

The C++ side already documents the incompatibility (NVTE_CHECK(!_with_cublasmp, "Bulk overlap is not supported with cuBlasMp")), so the Python side should at least surface this to the caller:

else:
    if with_cublasmp and method != "bulk":
        ub_obj = tex.CommOverlap(
            helper,
            local_rank,
            tp_size,
            num_comm_sm=num_sm,
            atomic_gemm=atomic_gemm,
        )
    else:
        if with_cublasmp and method == "bulk":
            warnings.warn(
                "cuBLASMp does not support bulk overlap; falling back to "
                "Userbuffers backend for bulk-overlap buffers.",
                stacklevel=2,
            )
        ub_obj = tex.CommOverlap(
            shape,
            ...
        )

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from ebe7679 to 5a8c7ae Compare March 17, 2026 20:02
@denera denera added 2.15.0 and removed 2.14.0 labels Mar 23, 2026
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
…fer requirement

Fix several issues preventing delayed-scaling FP8 collective GEMM tests
from passing with the cuBLASMp backend:

- Clean up stale NCCL unique ID files between test runs using a
  sync_global_devices barrier so crashed runs don't poison subsequent ones
- Use NumPy instead of JAX ops in process-0-only result checks to avoid
  multi-process XLA compilation deadlocks
- Expose nvte_built_with_cublasmp() to Python and add runtime skip logic
  in conftest.py and run_test_cgemm.sh
- Add cuBLASMp RS output path in gemm.cpp (cuBLASMp writes reduce-scattered
  result directly into D, unlike Userbuffers which uses an intermediate ubuf)

Also document on gemm() and collective_gemm_bootstrap() that XLA command
buffers must be disabled when using collective GEMM with communication
overlap, since both Userbuffers and cuBLASMp use internal CUDA streams
for NCCL collectives that break CUDA graph capture.
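The NumPy-based process-0-only check described above can be sketched as follows (shapes and names are assumed, not the actual test helper):

```python
import numpy as np

def rank0_allclose(result, reference, rank: int, rtol: float = 1e-2) -> bool:
    """Sketch: only rank 0 materializes host arrays and compares; other ranks
    return immediately, so no rank-dependent JAX op enters the multi-process
    XLA compile path and triggers a compilation deadlock."""
    if rank != 0:
        return True  # non-zero ranks skip the check entirely
    return bool(np.allclose(np.asarray(result), np.asarray(reference), rtol=rtol))

assert rank0_allclose([1.0, 2.0], [1.0, 2.0], rank=0)
assert rank0_allclose([1.0], [9.0], rank=3)  # mismatch ignored off rank 0
```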

Signed-off-by: adener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from ef09ae3 to caa741e Compare May 1, 2026 18:35
if aggregate:
test_cmd.append("--aggregate")
if use_cublasmp:
if quantization == "mxfp8":
Contributor comment:

@denera cublasMp does support MXFP8. Have you faced any issue?
