diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index fdfdee9b1c..7985df508c 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -8,7 +8,13 @@ import setuptools -from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled +from .utils import ( + all_files_in_dir, + common_lib_has_symbol, + cuda_version, + get_cuda_include_dirs, + debug_build_enabled, +) from typing import List @@ -87,6 +93,13 @@ def setup_pytorch_extension( libraries.append("nvshmem_host") cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + # Auto-detect cuSolverMp support from the already-built common library. + # The PyTorch extension only calls nvte_* C API functions (not cusolverMp + # directly), so it only needs the compile definition, not cusolverMp + # headers or libraries. + if common_lib_has_symbol("nvte_cusolvermp_ctx_create"): + cxx_flags.append("-DNVTE_WITH_CUSOLVERMP") + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] diff --git a/build_tools/utils.py b/build_tools/utils.py index 885901068a..3fd537db3d 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -374,3 +374,50 @@ def copy_common_headers( new_path = dst_dir / path.relative_to(src_dir) new_path.parent.mkdir(exist_ok=True, parents=True) shutil.copy(path, new_path) + + +def common_lib_has_symbol(symbol: str) -> bool: + """Check if the built libtransformer_engine.so exports a given symbol. + + Searches for the library in known build/install locations and uses + ``nm -D --defined-only`` to inspect the dynamic symbol table. + """ + root = Path(__file__).resolve().parent.parent + + # Candidate paths: editable-install root, default CMake build dir, + # and user-specified CMake build dir. + candidates = [ + root / "libtransformer_engine.so", + root / "build" / "cmake" / "libtransformer_engine.so", + ] + custom_build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR") + if custom_build_dir: + candidates.append(Path(custom_build_dir) / "libtransformer_engine.so") + + lib_path = None + for candidate in candidates: + if candidate.is_file(): + lib_path = candidate + break + + if lib_path is None: + raise FileNotFoundError( + "Could not find libtransformer_engine.so in any of the expected locations: " + + ", ".join(str(c) for c in candidates) + ) + + try: + result = subprocess.run( + ["nm", "-D", "--defined-only", str(lib_path)], + capture_output=True, + text=True, + check=True, + ) + except FileNotFoundError as e: + raise RuntimeError("'nm' is not available on this system.") from e + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"'nm' failed on {lib_path} (exit code {e.returncode}):\n{e.stderr}" + ) from e + + return symbol in result.stdout diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9d868d99cf..db13e9f1e0 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" # debug tests diff --git a/setup.py b/setup.py index 18bb736f24..708e9ff36e 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,11 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") print("CMAKE_FLAGS:", cmake_flags[-2:]) + if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))): + cmake_flags.append("-DNVTE_WITH_CUSOLVERMP=ON") + cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") + cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py new file mode 100644 index 0000000000..96663bda43 --- /dev/null +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed Newton-Schulz test worker. + +Launched via torchrun from test_newton_schulz.py. +""" + +import argparse +import sys + +import torch +import torch.distributed as dist +from torch.distributed.elastic.multiprocessing.errors import record + + +@record +def main(): + parser = argparse.ArgumentParser(description="Newton-Schulz distributed test") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"]) + parser.add_argument("--matrix-size", type=int, default=256) + parser.add_argument("--num-iterations", type=int, default=5) + parser.add_argument("--atol", type=float, default=1e-2) + parser.add_argument("--rtol", type=float, default=1e-2) + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16 + N = args.matrix_size + + # Ensure N is divisible by world_size + assert N % world_size == 0, f"Matrix size {N} must be divisible by world_size {world_size}" + + # Create a random symmetric positive definite matrix on rank 0 + # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0, 1) + # This ensures Newton-Schulz converges + if rank == 0: + torch.manual_seed(42) + Q, _ = torch.linalg.qr(torch.randn(N, N, device="cuda", dtype=torch.float32)) + eigenvalues = torch.rand(N, device="cuda", dtype=torch.float32) * 0.8 + 0.1 + A = Q @ torch.diag(eigenvalues) @ Q.T + A = A.to(dtype) + else: + A = torch.empty(N, N, device="cuda", dtype=dtype) + + # Broadcast the full matrix to all ranks + dist.broadcast(A, src=0) + + # Scatter rows to each rank + local_rows = N // world_size + x_local = A[rank * local_rows : (rank + 1) * local_rows, :].contiguous() + + # Run the distributed Newton-Schulz + from transformer_engine.pytorch.newton_schulz import newton_schulz + + group = dist.group.WORLD + newton_schulz(x_local, group, args.num_iterations) + + # Gather results + gathered = [torch.empty_like(x_local) for _ in range(world_size)] + dist.all_gather(gathered, x_local) + X = torch.cat(gathered, dim=0) + + # Check: the resulting matrix should be orthogonal + if rank == 0: + XXT = X @ X.t() + I = torch.eye(N, device=XXT.device, dtype=XXT.dtype) + max_diff = (XXT - I).abs().max().item() + print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True) + + if torch.allclose(XXT, I, atol=args.atol, rtol=args.rtol): + print("NUMERICAL CHECK PASSED", flush=True) + else: + print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr) + sys.exit(1) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py new file mode 100644 index 0000000000..5646d11dd9 --- /dev/null +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for distributed Newton-Schulz matrix orthogonalization.""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +if torch.cuda.device_count() < 2: + pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True) + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS = torch.cuda.device_count() +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) +@pytest.mark.parametrize("matrix_size", [256]) +@pytest.mark.parametrize("num_iterations", [5, 15]) +def test_newton_schulz(dtype, matrix_size, num_iterations): + """Test distributed Newton-Schulz matrix orthogonalization.""" + test_path = TEST_ROOT / "run_newton_schulz.py" + test_cmd = LAUNCH_CMD + [ + str(test_path), + f"--dtype={dtype}", + f"--matrix-size={matrix_size}", + f"--num-iterations={num_iterations}", + ] + if dtype == "bfloat16": + test_cmd += ["--atol=5e-2", "--rtol=5e-2"] + + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if ( + result.returncode != 0 + or "NUMERICAL CHECK FAILED" in result.stderr.decode() + or "NUMERICAL CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError( + "Newton-Schulz test failed.\n" + f"stdout: {result.stdout.decode()}\n" + f"stderr: {result.stderr.decode()}" + ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index efe958f844..0edcf33e14 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -227,6 +227,11 @@ list(APPEND transformer_engine_SOURCES comm_gemm/comm_gemm.cpp) endif() +if (NVTE_WITH_CUSOLVERMP) +list(APPEND transformer_engine_SOURCES + newton_schulz/newton_schulz.cpp) +endif() + add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -300,6 +305,19 @@ if (NVTE_WITH_CUBLASMP) message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") endif() +option(NVTE_WITH_CUSOLVERMP "Use cuSolverMp for distributed Newton-Schulz" OFF) +if (NVTE_WITH_CUSOLVERMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUSOLVERMP) + target_include_directories(transformer_engine PRIVATE ${CUSOLVERMP_DIR}/include) + find_library(CUSOLVERMP_LIB + NAMES cusolverMp libcusolverMp + PATHS ${CUSOLVERMP_DIR} + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${CUSOLVERMP_LIB}) + message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h new file mode 100644 index 0000000000..b604b2db39 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file newton_schulz.h + * \brief Functions for distributed Newton-Schulz matrix orthogonalization. + * + * This API is a TE-native binding to the cuSolverMp library. + * It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ +#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +typedef struct NVTECusolverMpCtx NVTECusolverMpCtx; + +/*! \brief Create a cuSolverMp context for Newton-Schulz operations. + * + * Creates a dedicated CUDA stream internally (cuSolverMp requires a + * non-default stream). + * + * \param[in] comm NCCL communicator. + * \param[in] nranks Number of ranks. + * \param[in] rank Local rank. + */ +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank); + +/*! \brief Destroy a cuSolverMp context. + * + * \param[in] ctx Context to destroy. + */ +void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); + +/*! \brief Compute Newton-Schulz matrix orthogonalization in-place. + * + * \param[in] ctx cuSolverMp context. + * \param[in] m Global number of rows. + * \param[in] n Global number of columns. + * \param[in,out] x Local part of the matrix (modified in-place). + * \param[in] num_iterations Number of Newton-Schulz iterations. + * \param[in] coefficients Array of polynomial coefficients (length depends on polynomial + * degree used internally by cuSolverMp). + * \param[in] num_coefficients Number of elements in the coefficients array. + * \param[in] caller_stream CUDA stream on which the caller produced the input tensor. + * Used for event-based synchronisation with the internal stream. + */ +void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, + int64_t num_iterations, const float* coefficients, int64_t num_coefficients, + cudaStream_t caller_stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp new file mode 100644 index 0000000000..bf9a92aac8 --- /dev/null +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -0,0 +1,194 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/newton_schulz.h" + +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +using namespace transformer_engine; + +// RAII wrapper types for cuSolverMp handles (outside anonymous namespace because +// CusolverMpHandle and CusolverMpGrid are used in the NVTECusolverMpCtx struct) + +struct CusolverMpHandleDeleter { + void operator()(cusolverMpHandle_t handle) const { cusolverMpDestroy(handle); } +}; +using CusolverMpHandle = + std::unique_ptr, CusolverMpHandleDeleter>; + +struct CusolverMpGridDeleter { + void operator()(cusolverMpGrid_t grid) const { cusolverMpDestroyGrid(grid); } +}; +using CusolverMpGrid = + std::unique_ptr, CusolverMpGridDeleter>; + +namespace { + +struct CusolverMpMatrixDescDeleter { + void operator()(cusolverMpMatrixDescriptor_t desc) const { cusolverMpDestroyMatrixDesc(desc); } +}; +using CusolverMpMatrixDesc = std::unique_ptr, + CusolverMpMatrixDescDeleter>; + +struct CusolverMpNSDescDeleter { + void operator()(cusolverMpNewtonSchulzDescriptor_t desc) const { + cusolverMpNewtonSchulzDescriptorDestroy(desc); + } +}; +using CusolverMpNSDesc = std::unique_ptr, + CusolverMpNSDescDeleter>; + +CusolverMpHandle MakeCusolverMpHandle(int device_id, cudaStream_t stream) { + cusolverMpHandle_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpCreate(&raw, device_id, stream)); + return CusolverMpHandle(raw); +} + +CusolverMpGrid MakeCusolverMpGrid(cusolverMpHandle_t handle, ncclComm_t comm, int32_t nprow, + int32_t npcol, cusolverMpGridMapping_t mapping) { + cusolverMpGrid_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpCreateDeviceGrid(handle, &raw, comm, nprow, npcol, mapping)); + return CusolverMpGrid(raw); +} + +CusolverMpMatrixDesc MakeCusolverMpMatrixDesc(cusolverMpGrid_t grid, cudaDataType_t dtype, + int64_t m, int64_t n, int64_t mb, int64_t nb, + uint32_t rsrc, uint32_t csrc, int64_t lld) { + cusolverMpMatrixDescriptor_t raw{}; + NVTE_CHECK_CUSOLVERMP( + cusolverMpCreateMatrixDesc(&raw, grid, dtype, m, n, mb, nb, rsrc, csrc, lld)); + return CusolverMpMatrixDesc(raw); +} + +CusolverMpNSDesc MakeCusolverMpNSDesc() { + cusolverMpNewtonSchulzDescriptor_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulzDescriptorCreate(&raw)); + return CusolverMpNSDesc(raw); +} + +} // namespace + +struct NVTECusolverMpCtx { + int64_t nranks; + int64_t rank; + cudaStream_t stream; + cudaEvent_t in_ready; + cudaEvent_t out_ready; + CusolverMpHandle handle; + CusolverMpGrid grid; + void* workspace; + size_t workspace_size; +}; + +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_API_CALL(nvte_cusolvermp_ctx_create); + int device_id{}; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + + cudaStream_t stream{}; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + cudaEvent_t in_ready{}; + NVTE_CHECK_CUDA(cudaEventCreate(&in_ready)); + cudaEvent_t out_ready{}; + NVTE_CHECK_CUDA(cudaEventCreate(&out_ready)); + + auto handle = MakeCusolverMpHandle(device_id, stream); + auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, CUSOLVERMP_GRID_MAPPING_COL_MAJOR); + + return new NVTECusolverMpCtx{ + .nranks = nranks, + .rank = rank, + .stream = stream, + .in_ready = in_ready, + .out_ready = out_ready, + .handle = std::move(handle), + .grid = std::move(grid), + .workspace = nullptr, + .workspace_size = 0, + }; +} + +void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { + NVTE_API_CALL(nvte_cusolvermp_ctx_destroy); + if (ctx->workspace) { + cudaFree(ctx->workspace); + } + // Destroy handle and grid before the stream they depend on + ctx->handle.reset(); + ctx->grid.reset(); + cudaEventDestroy(ctx->in_ready); + cudaEventDestroy(ctx->out_ready); + cudaStreamDestroy(ctx->stream); + delete ctx; +} + +void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, + int64_t num_iterations, const float* coefficients, int64_t num_coefficients, + cudaStream_t caller_stream) { + NVTE_API_CALL(nvte_newton_schulz); + NVTE_CHECK(num_coefficients == num_iterations * 3, num_iterations, " iterations require ", + num_iterations * 3, " coefficients, but ", num_coefficients, " are passed"); + const auto* t = convertNVTETensorCheck(x); + + // Make the internal stream wait for the caller's stream so that + // the input tensor is ready before cuSolverMp reads it. + NVTE_CHECK_CUDA(cudaEventRecord(ctx->in_ready, caller_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream, ctx->in_ready)); + + // Block size for ScaLAPACK-style distribution + const int64_t mb = (m + ctx->nranks - 1) / ctx->nranks; + const int64_t nb = n; + + // Compute local leading dimension + const int64_t local_rows = cusolverMpNUMROC(m, mb, ctx->rank, 0, ctx->nranks); + const int64_t lld = std::max(local_rows, static_cast(1)); + + const cudaDataType_t cuda_dtype = get_cuda_dtype(t->dtype()); + + // Create matrix descriptor + auto mat_desc = MakeCusolverMpMatrixDesc(ctx->grid.get(), cuda_dtype, m, n, mb, nb, 0, 0, lld); + + // Create Newton-Schulz descriptor + auto ns_desc = MakeCusolverMpNSDesc(); + + // Query workspace sizes + size_t wrksp_size_device = 0; + size_t wrksp_size_host = 0; + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz_bufferSize( + ctx->handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + coefficients, CUDA_R_32F, &wrksp_size_device, &wrksp_size_host)); + + // Allocate/grow device workspace + if (ctx->workspace_size < wrksp_size_device) { + if (ctx->workspace) { + NVTE_CHECK_CUDA(cudaFree(ctx->workspace)); + } + NVTE_CHECK_CUDA(cudaMalloc(&ctx->workspace, wrksp_size_device)); + ctx->workspace_size = wrksp_size_device; + } + + // Allocate host workspace + std::vector workspace_host(wrksp_size_host); + + // Execute Newton-Schulz + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz( + ctx->handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), + workspace_host.size(), nullptr)); + + // Make the caller's stream wait for the internal stream so that + // the output tensor is ready before the caller uses it. + NVTE_CHECK_CUDA(cudaEventRecord(ctx->out_ready, ctx->stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(caller_stream, ctx->out_ready)); +} diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index c542afa393..cf72cab048 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -18,6 +18,10 @@ #include #endif // NVTE_WITH_CUBLASMP +#ifdef NVTE_WITH_CUSOLVERMP +#include +#endif // NVTE_WITH_CUSOLVERMP + #include #include #include @@ -106,6 +110,18 @@ #endif // NVTE_WITH_CUBLASMP +#ifdef NVTE_WITH_CUSOLVERMP + +#define NVTE_CHECK_CUSOLVERMP(expr) \ + do { \ + const cusolverStatus_t status_NVTE_CHECK_CUSOLVERMP = (expr); \ + if (status_NVTE_CHECK_CUSOLVERMP != CUSOLVER_STATUS_SUCCESS) { \ + NVTE_ERROR("cuSolverMp Error: ", std::to_string(status_NVTE_CHECK_CUSOLVERMP)); \ + } \ + } while (false) + +#endif // NVTE_WITH_CUSOLVERMP + #define NVTE_CHECK_NCCL(expr) \ do { \ const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5e1eb6954b..ec9caa245f 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -59,6 +59,7 @@ from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch.newton_schulz import newton_schulz from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.quantized_tensor import Quantizer diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f7cf32eaf6..27585bf0c7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -504,6 +504,21 @@ void nvshmem_finalize(); void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream); +/*************************************************************************************************** + * Newton-Schulz (cuSolverMp) + **************************************************************************************************/ + +#ifdef NVTE_WITH_CUSOLVERMP + +int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank); + +void cusolvermp_ctx_destroy(int64_t ctx_ptr); + +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, + std::vector coefficients); + +#endif // NVTE_WITH_CUSOLVERMP + } // namespace transformer_engine::pytorch /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp new file mode 100644 index 0000000000..3615d5be31 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -0,0 +1,44 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifdef NVTE_WITH_CUSOLVERMP + +#include "transformer_engine/newton_schulz.h" + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank) { + auto comm = reinterpret_cast(nccl_comm_ptr); + auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank); + return reinterpret_cast(ctx); +} + +void cusolvermp_ctx_destroy(int64_t ctx_ptr) { + auto* ctx = reinterpret_cast(ctx_ptr); + nvte_cusolvermp_ctx_destroy(ctx); +} + +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, + std::vector coefficients) { + auto* ctx = reinterpret_cast(ctx_ptr); + + // Build NVTETensor from PyTorch tensor + auto x_sizes = x.sizes().vec(); + std::vector shape(x_sizes.begin(), x_sizes.end()); + + auto te_dtype = GetTransformerEngineDType(x.scalar_type()); + TensorWrapper x_tensor(x.data_ptr(), shape, te_dtype); + + auto caller_stream = at::cuda::getCurrentCUDAStream().stream(); + nvte_newton_schulz(ctx, m, n, x_tensor.data(), num_iterations, coefficients.data(), + static_cast(coefficients.size()), caller_stream); +} + +} // namespace transformer_engine::pytorch + +#endif // NVTE_WITH_CUSOLVERMP diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 79dd9ea5ce..71fecf5651 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -443,6 +443,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda, "Fused compute E8M0 scale_inv from amax", py::call_guard()); +#ifdef NVTE_WITH_CUSOLVERMP + // Newton-Schulz (cuSolverMp) + m.def("cusolvermp_ctx_create", &transformer_engine::pytorch::cusolvermp_ctx_create, + "Create cuSolverMp context for Newton-Schulz", py::arg("nccl_comm_ptr"), py::arg("nranks"), + py::arg("rank"), py::call_guard()); + m.def("cusolvermp_ctx_destroy", &transformer_engine::pytorch::cusolvermp_ctx_destroy, + "Destroy cuSolverMp context", py::arg("ctx_ptr"), py::call_guard()); + m.def("newton_schulz", &transformer_engine::pytorch::newton_schulz, + "Newton-Schulz inverse square root", py::arg("ctx_ptr"), py::arg("m"), py::arg("n"), + py::arg("x"), py::arg("num_iterations"), py::arg("coefficients"), + py::call_guard()); +#endif // NVTE_WITH_CUSOLVERMP + // Comm+GEMM Overlap m.def("bulk_overlap_ag_with_external_gemm", &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm, diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py new file mode 100644 index 0000000000..3e4c844bdc --- /dev/null +++ b/transformer_engine/pytorch/newton_schulz.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed Newton-Schulz inverse square root via cuSolverMp.""" + +from typing import List, Optional + +import torch +import torch.distributed as dist + +import transformer_engine_torch as tex + + +def _get_nccl_comm_ptr(group: dist.ProcessGroup) -> int: + """Extract the raw NCCL communicator pointer from a PyTorch process group.""" + backend = dist.get_backend(group) + if backend != "nccl": + raise RuntimeError(f"newton_schulz requires NCCL backend, got '{backend}'") + nccl_backend = group._get_backend(torch.device("cuda")) + return nccl_backend._comm_ptr() + + +def newton_schulz( + x: torch.Tensor, + group: dist.ProcessGroup, + num_iterations: int = 5, + coefficients: Optional[List[float]] = None, +) -> None: + """Compute Newton-Schulz inverse square root in-place on a distributed matrix. + + Parameters + ---------- + x : torch.Tensor + Local part of the distributed matrix (modified in-place). + Must be a 2D CUDA tensor of type float32 or bfloat16. + group : torch.distributed.ProcessGroup + Process group with NCCL backend for distributed communication. + num_iterations : int, optional + Number of Newton-Schulz iterations. Default: 5. + coefficients : list of float, optional + Polynomial coefficients for the Newton-Schulz iteration. + """ + QUINTIC_COEFFICIENTS = [ + 4.0848, + -6.8946, + 2.9270, + 3.9505, + -6.3029, + 2.6377, + 3.7418, + -5.5913, + 2.3037, + 2.8769, + -3.1427, + 1.2046, + 2.8366, + -3.0525, + 1.2012, + ] + if coefficients is None: + coefficients = ( + QUINTIC_COEFFICIENTS if num_iterations == 5 else [1.5, -0.5, 0.0] * num_iterations + ) + assert ( + len(coefficients) == num_iterations * 3 + ), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations" + + if x.dim() != 2: + raise ValueError(f"Expected 2D tensor, got {x.dim()}D") + if not x.is_cuda: + raise ValueError("Input tensor must be on CUDA device") + + if not hasattr(tex, "newton_schulz"): + raise RuntimeError( + "newton_schulz requires Transformer Engine to be built with NVTE_WITH_CUSOLVERMP=1" + ) + + nccl_comm_ptr = _get_nccl_comm_ptr(group) + nranks = dist.get_world_size(group) + rank = dist.get_rank(group) + + # Global matrix dimensions + m = x.size(0) * nranks # rows are distributed across ranks + n = x.size(1) + + ctx_ptr = tex.cusolvermp_ctx_create(nccl_comm_ptr, nranks, rank) + try: + tex.newton_schulz(ctx_ptr, m, n, x, num_iterations, coefficients) + finally: + tex.cusolvermp_ctx_destroy(ctx_ptr)