Pytorch binding for cublas gemm + Grouped Linear integration#2669

Draft
vthumbe1503 wants to merge 11 commits into NVIDIA:main from vthumbe1503:users/vthumbe/pytorch_binding_for_cublas_gemm

Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ksivaman and others added 2 commits February 6, 2026 06:10
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title from "Users/vthumbe/pytorch binding for cublas gemm" to "Pytorch binding for cublas gemm + Grouped Linear integration" Feb 10, 2026
vthumbe1503 and others added 4 commits February 11, 2026 03:11
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…b.com:vthumbe1503/TransformerEngine into users/vthumbe/pytorch_binding_for_cublas_gemm
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 requested a review from ptrendx February 11, 2026 17:15
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review March 6, 2026 17:46

greptile-apps bot commented Mar 6, 2026

Greptile Summary

This PR introduces a PyTorch binding for the new nvte_grouped_gemm cuBLAS 13.2+ API (Blackwell SM100+ only) and integrates it into GroupedLinear via a new single_weight parameter. With single_weight=True, all expert weights are stored as a single contiguous tensor, which enables the cuBLAS grouped GEMM path when m_splits is passed as a device tensor.

Key changes:

  • cublaslt_grouped_gemm.cu: New nvte_grouped_gemm implementation using cublasLtGroupedMatrixLayout; handles FP8/BF16/FP16 inputs, per-tensor alpha/beta, rowwise/columnwise operand selection, and a device-side setup kernel that computes per-tensor pointers and storage dimensions.
  • gemm.py / gemm.cpp / pybind.cpp: New general_grouped_gemm_for_grouped_tensor Python binding that wraps the cuBLAS path; the workspace size helper get_grouped_gemm_setup_workspace_size uses 8 * aligned_ptr_size while the C++ required_setup_size uses 6 * ptr_size, so the Python side over-allocates (safe at runtime but contradicts its own "Must match" comment).
  • grouped_linear.py: single_weight=True stores all GEMM weights as one flat weight0 parameter. The make_grouped_weights method's if self.single_weight: block has no early return, causing an AttributeError on weight1, weight2, … when single_weight=True and single_grouped_parameter=True are both set (non-FP8 path).
  • Tests: C++ tests updated with larger shapes and a required cudaDeviceSynchronize. test_grouped_linear_accuracy_cutlass sets NVTE_USE_CUTLASS_GROUPED_GEMM=1 without cleanup, risking test pollution for subsequent tests in the same process.
  • type_converters.cpp: New GroupedTensorFromPyTorchGroupedTensor correctly marshals all GroupedTensor metadata fields to the C++ wrapper.

Confidence Score: 3/5

  • Not safe to merge without fixing the make_grouped_weights crash and the workspace size formula mismatch.
  • Two logic bugs were found: (1) make_grouped_weights raises AttributeError when single_weight=True and single_grouped_parameter=True due to missing early return; (2) get_grouped_gemm_setup_workspace_size computes 8 * aligned_ptr_size instead of 6 * ptr_size, contradicting its own documentation comment. While the over-allocation is not a crash today, it is technically incorrect and will silently diverge from the C++ layout if the workspace structure ever changes. The core cuBLAS kernel, C++ binding, and type converter are well-structured.
  • transformer_engine/pytorch/cpp_extensions/gemm.py (workspace size formula) and transformer_engine/pytorch/module/grouped_linear.py (missing return in make_grouped_weights).

Important Files Changed

Filename Overview
transformer_engine/pytorch/cpp_extensions/gemm.py Adds general_grouped_gemm_for_grouped_tensor Python binding for cuBLAS 13.2+ grouped GEMM; workspace size formula uses 8 * aligned_ptr_size instead of 6 * ptr_size as in C++, contradicting its own "Must match" comment.
transformer_engine/pytorch/module/grouped_linear.py Adds single_weight parameter to GroupedLinear for contiguous weight storage; make_grouped_weights falls through its if self.single_weight: block without returning, crashing with AttributeError when single_weight=True and single_grouped_parameter=True in non-FP8 mode.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu New cuBLAS 13.2+ / Blackwell (SM100+) grouped GEMM kernel using cublasLtGroupedMatrixLayout; setup workspace layout, operand selection, FP8 scale pointer handling, and algorithm heuristics look correct.
transformer_engine/pytorch/csrc/extensions/gemm.cpp Adds te_general_grouped_gemm_for_grouped_tensor C++ entry point that delegates to nvte_grouped_gemm; tensor conversion, workspace setup, and optional config path all look correct.
transformer_engine/pytorch/csrc/type_converters.cpp Adds GroupedTensorFromPyTorchGroupedTensor converter that correctly maps all GroupedTensor fields (rowwise/columnwise data, scale, amax, first_dims, last_dims, tensor_offsets) to the C++ wrapper.
tests/pytorch/test_numerics.py Adds test_grouped_linear_m_splits_tensor covering the new single/multi-weight cuBLAS path and forward/backward pass; NVTE_USE_CUTLASS_GROUPED_GEMM=1 in test_grouped_linear_accuracy_cutlass is set but never cleaned up, risking test pollution.
tests/cpp/operator/test_grouped_gemm.cu Updates C++ grouped GEMM tests with larger shapes and adds cudaDeviceSynchronize before result comparison; removes the unsupported transa=true FP8 test case.
tests/pytorch/test_grouped_tensor.py Comprehensive test suite for GroupedTensor construction, quantization, CUDA graph capturability, and varying shapes; changes look correct and thorough.

Sequence Diagram

sequenceDiagram
    participant PY as GroupedLinear (Python)
    participant GemmPY as gemm.py
    participant CPP as gemm.cpp (C++)
    participant TC as type_converters.cpp
    participant CUDA as cublaslt_grouped_gemm.cu

    PY->>GemmPY: general_grouped_gemm_for_grouped_tensor(A, B, out, m_splits_tensor)
    GemmPY->>GemmPY: allocate workspace_setup (get_grouped_gemm_setup_workspace_size)
    GemmPY->>GemmPY: allocate workspace_cublas (32 MiB)
    GemmPY->>CPP: tex.te_general_grouped_gemm_for_grouped_tensor(A, B, C=out, D=out, alpha, beta, ...)
    CPP->>TC: GroupedTensorFromPyTorchGroupedTensor(A)
    TC-->>CPP: GroupedTensorWrapper (rowwise/columnwise data, shape metadata)
    CPP->>TC: GroupedTensorFromPyTorchGroupedTensor(B)
    TC-->>CPP: GroupedTensorWrapper
    CPP->>TC: GroupedTensorFromPyTorchGroupedTensor(D/C)
    TC-->>CPP: GroupedTensorWrapper
    CPP->>CUDA: nvte_grouped_gemm(A, B, C, D, alpha, beta, ws_setup, ws_cublas, config, stream)
    CUDA->>CUDA: validate_grouped_gemm_inputs()
    CUDA->>CUDA: select_grouped_operand() — rowwise vs columnwise, FP8 TN-only logic
    CUDA->>CUDA: launch setup_grouped_gemm_kernel() — fills A/B/C/D pointer arrays + dimensions
    CUDA->>CUDA: cudaDeviceSynchronize (implicit via stream ordering)
    CUDA->>CUDA: init_matrix_layouts() + init_matmul_desc() + set_fp8_scale_pointers()
    CUDA->>CUDA: select_grouped_gemm_algo() (cuBLASLt heuristics)
    CUDA->>CUDA: cublasLtMatmul() — batched grouped GEMM
    CUDA-->>CPP: result in D (GroupedTensor data buffer)
    CPP-->>GemmPY: D (same Python GroupedTensor object)
    GemmPY-->>PY: output GroupedTensor

Comments Outside Diff (2)

  1. transformer_engine/pytorch/module/grouped_linear.py, line 940-978 (link)

    make_grouped_weights crashes with AttributeError when single_weight=True and single_grouped_parameter=True

    When single_weight=True, only weight0 is registered as a nn.Parameter (num_tensors = 1 in __init__). make_grouped_weights is called by reset_parameters when single_grouped_parameter=True.

    Inside make_grouped_weights, the if self.single_weight: block (lines 940–950) sets grouped_weight_storage correctly, but there is no return or else guard — execution falls through unconditionally to line 951 and beyond.

    Because single_weight=True disallows FP8 primary weights (enforced earlier in __init__), weight_quantizers[0] is always None, making recipe evaluate to None and the early-return at line 957 never fires. Execution then reaches:

    weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]

    For num_gemms > 1, this raises AttributeError on weight1, weight2, …, because those attributes were never registered.

    Fix: add an early return after the single_weight block is done, or guard the rest of the function with else:

    if self.single_weight:
        weight = getattr(self, "weight0")
        ...
        self.grouped_weight_storage = GroupedTensor(...)
        return   # ← prevent fall-through to the multi-weight path
  2. tests/pytorch/test_numerics.py, line 2095 (link)

    Environment variable set but never reset — test pollution risk

    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" is set unconditionally at the start of the test body and never cleaned up. If this test is parametrized (it is) or if another test runs in the same process after it, the environment variable will still be "1" and could silently alter their behaviour.

    Use a try/finally or a pytest monkeypatch/fixture to guarantee cleanup:

    def test_grouped_linear_accuracy_cutlass(...):
        os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
        try:
            test_grouped_linear_accuracy(...)
        finally:
            os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
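Beyond try/finally, a small context manager gives the same guarantee and can be shared across tests. This is a sketch, not code from the PR; temp_env is a hypothetical helper, and pytest's built-in monkeypatch.setenv fixture achieves the same cleanup automatically.

```python
import os
from contextlib import contextmanager

@contextmanager
def temp_env(name: str, value: str):
    """Set an environment variable for the duration of a block, then restore
    the previous value (or remove the variable if it was unset before)."""
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old

# Usage: the variable cannot leak past the with-block, even if the test raises.
with temp_env("NVTE_USE_CUTLASS_GROUPED_GEMM", "1"):
    pass  # run the CUTLASS-path test body here
```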

Last reviewed commit: 38cf811

Comment on lines 1127 to 1128
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

AttributeError in backward_dw when single_weight=True

In backward_dw, weight_params is constructed using range(self.num_gemms), but when single_weight=True, only weight0 is registered as a parameter (self.num_weight_params = 1). Accessing weight1, weight2, etc. will raise AttributeError.

weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]  # BUG: should use num_weight_params
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]       # BUG: same issue

This bug surfaces whenever single_weight=True AND delay_wgrad_compute=True. The fix should use self.num_weight_params:

Suggested change
-weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
-bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
+weight_params = [getattr(self, f"weight{i}") for i in range(self.num_weight_params)]
+bias_params = [getattr(self, f"bias{i}") for i in range(self.num_weight_params)]

Comment on lines +799 to +801
if single_weight:
    bias = False
    return_bias = False

Silent bias override with no warning

When single_weight=True, bias and return_bias are silently overridden to False without any warning to the caller. A user who passes bias=True, single_weight=True will get no bias but also no error message or warning. Add a warnings.warn(...) call to notify callers of the implicit override:

Suggested change
-if single_weight:
-    bias = False
-    return_bias = False
+if single_weight:
+    if bias:
+        warnings.warn(
+            "bias=True is not supported with single_weight=True; bias has been disabled.",
+            UserWarning,
+        )
+    bias = False
+    return_bias = False
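To see why the warning matters, here is a toy stand-in (not the PR's actual __init__ logic; resolve_bias is a hypothetical name) showing that the override becomes observable to the caller only when a warning is emitted:

```python
import warnings

def resolve_bias(single_weight: bool, bias: bool) -> bool:
    """Toy stand-in for the proposed logic: single_weight forces bias off,
    but warns so the caller's bias=True is not silently dropped."""
    if single_weight:
        if bias:
            warnings.warn(
                "bias=True is not supported with single_weight=True; "
                "bias has been disabled.",
                UserWarning,
            )
        return False
    return bias
```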

vthumbe1503 and others added 2 commits March 6, 2026 18:15
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as draft March 6, 2026 18:42
Comment on lines +296 to +308
def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int:
    """Return workspace size for grouped GEMM pointer setup.

    Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu.
    """
    ptr_bytes = ctypes.sizeof(ctypes.c_void_p)
    int_bytes = ctypes.sizeof(ctypes.c_int)
    ptr_size = num_tensors * ptr_bytes
    int_size = num_tensors * int_bytes
    k_ptr_alignment = 16
    aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment
    size = 8 * aligned_ptr_size + 6 * int_size
    alignment = 256
    return ((size + alignment - 1) // alignment) * alignment

Workspace size formula doesn't match the C++ implementation

The comment explicitly states this must match GroupedGemmSetupWorkspace::required_setup_size, but the two formulas disagree in both the number of pointer arrays and the intermediate alignment.

C++ (cublaslt_grouped_gemm.cu):

// Layout: 6 ptr arrays, then 6 int arrays
size_t size = 6 * ptr_size + 6 * int_size;

Python (here):

size = 8 * aligned_ptr_size + 6 * int_size  # 8 × (16-byte-aligned ptr arrays)

The workspace has exactly 6 pointer arrays (A_ptrs, B_ptrs, C_ptrs, D_ptrs, alpha_ptrs, beta_ptrs) as confirmed by from_buffers(), so the coefficient should be 6, not 8. The extra per-array 16-byte alignment (k_ptr_alignment) also has no counterpart in the C++ code.

Because Python allocates more than C++ requires, validate_and_get_workspace_ptr (which checks provided_size >= required_size) will always pass and there is no buffer overflow. However the over-allocation grows proportionally with num_tensors and the comment is factually wrong. The formula should be corrected to match C++:

def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int:
    ptr_bytes = ctypes.sizeof(ctypes.c_void_p)
    int_bytes = ctypes.sizeof(ctypes.c_int)
    ptr_size = num_tensors * ptr_bytes
    int_size = num_tensors * int_bytes
    # Layout: 6 ptr arrays, then 6 int arrays — must match
    # GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu
    size = 6 * ptr_size + 6 * int_size
    alignment = 256
    return ((size + alignment - 1) // alignment) * alignment
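As a sanity check on the mismatch, the following sketch (hypothetical helper names; assumes 64-bit pointers and applies the same final 256-byte round-up to both sides) computes both formulas and confirms the current Python side always allocates at least as much as the C++ layout requires, strictly more as num_tensors grows:

```python
import ctypes

def round_up(n: int, a: int) -> int:
    """Round n up to the next multiple of a."""
    return ((n + a - 1) // a) * a

def python_side_size(num_tensors: int) -> int:
    # Formula currently in gemm.py: 8 x (16-byte-aligned ptr arrays) + 6 int arrays.
    ptr_size = num_tensors * ctypes.sizeof(ctypes.c_void_p)
    int_size = num_tensors * ctypes.sizeof(ctypes.c_int)
    aligned_ptr_size = round_up(ptr_size, 16)
    return round_up(8 * aligned_ptr_size + 6 * int_size, 256)

def cpp_side_size(num_tensors: int) -> int:
    # Formula in cublaslt_grouped_gemm.cu: exactly 6 ptr arrays + 6 int arrays.
    ptr_size = num_tensors * ctypes.sizeof(ctypes.c_void_p)
    int_size = num_tensors * ctypes.sizeof(ctypes.c_int)
    return round_up(6 * ptr_size + 6 * int_size, 256)

for n in (1, 8, 64, 512):
    # Over-allocation is safe (>= check passes) but diverges for large n.
    assert python_side_size(n) >= cpp_side_size(n)
```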
