From 948037e3dfffc908a2778845cb998f067fab8459 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Sun, 8 Feb 2026 22:38:04 +0000 Subject: [PATCH 01/30] [Common] Add Newton-Schulz inverse square root C API via cuSolverMp Add a new distributed Newton-Schulz inverse square root API to Transformer Engine's common C library. This wraps the cusolverMpNewtonSchulz library function, following the same pattern as the existing cuBLASMp integration for comm_gemm. New files: - newton_schulz.h: Public C API header with context management and computation functions - newton_schulz/newton_schulz.cpp: Implementation with RAII wrappers for cuSolverMp handles Build integration: - New NVTE_WITH_CUSOLVERMP CMake option and CUSOLVERMP_HOME env var - NVTE_CHECK_CUSOLVERMP error checking macro in logging.h - Conditional compilation guarded by NVTE_WITH_CUSOLVERMP Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- setup.py | 5 + transformer_engine/common/CMakeLists.txt | 18 ++ .../transformer_engine/newton_schulz.h | 68 +++++++ .../common/newton_schulz/newton_schulz.cpp | 183 ++++++++++++++++++ transformer_engine/common/util/logging.h | 17 ++ 5 files changed, 291 insertions(+) create mode 100644 transformer_engine/common/include/transformer_engine/newton_schulz.h create mode 100644 transformer_engine/common/newton_schulz/newton_schulz.cpp diff --git a/setup.py b/setup.py index 18bb736f24..708e9ff36e 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,11 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") print("CMAKE_FLAGS:", cmake_flags[-2:]) + if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))): + cmake_flags.append("-DNVTE_WITH_CUSOLVERMP=ON") + cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") + cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index efe958f844..0edcf33e14 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -227,6 +227,11 @@ list(APPEND transformer_engine_SOURCES comm_gemm/comm_gemm.cpp) endif() +if (NVTE_WITH_CUSOLVERMP) +list(APPEND transformer_engine_SOURCES + newton_schulz/newton_schulz.cpp) +endif() + add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -300,6 +305,19 @@ if (NVTE_WITH_CUBLASMP) message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") endif() +option(NVTE_WITH_CUSOLVERMP "Use cuSolverMp for distributed Newton-Schulz" OFF) +if (NVTE_WITH_CUSOLVERMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUSOLVERMP) + target_include_directories(transformer_engine PRIVATE ${CUSOLVERMP_DIR}/include) + find_library(CUSOLVERMP_LIB + NAMES cusolverMp libcusolverMp + PATHS ${CUSOLVERMP_DIR} + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${CUSOLVERMP_LIB}) + message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h new file mode 100644 index 0000000000..b540b9db54 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file newton_schulz.h + * \brief Functions for distributed Newton-Schulz inverse square root. + * + * This API is a TE-native binding to the cuSolverMp library. + * It computes an iterative Newton-Schulz inverse square root + * approximation on a distributed matrix. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ +#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +typedef struct NVTECusolverMpCtx NVTECusolverMpCtx; + +/*! \brief Create a cuSolverMp context for Newton-Schulz operations. + * + * \param[in] comm NCCL communicator. + * \param[in] nranks Number of ranks. + * \param[in] rank Local rank. + */ +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank); + +/*! \brief Destroy a cuSolverMp context. + * + * \param[in] ctx Context to destroy. + */ +void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); + +/*! \brief Compute Newton-Schulz inverse square root in-place. + * + * Performs iterative Newton-Schulz approximation of the inverse square root + * on a distributed matrix using cuSolverMp. + * + * \param[in] ctx cuSolverMp context. + * \param[in] m Global number of rows. + * \param[in] n Global number of columns. + * \param[in,out] x Local part of the matrix (modified in-place). + * \param[in] num_iterations Number of Newton-Schulz iterations. + * \param[in] coefficients Array of polynomial coefficients (length depends on polynomial + * degree used internally by cuSolverMp). + * \param[in] num_coefficients Number of elements in the coefficients array. + * \param[in] stream CUDA stream. + */ +void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, + int64_t num_iterations, const float* coefficients, + int64_t num_coefficients, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp new file mode 100644 index 0000000000..f70986a122 --- /dev/null +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -0,0 +1,183 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/newton_schulz.h" + +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +using namespace transformer_engine; + +namespace { + +template +auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + NVTE_CHECK_CUDA(create_fn(&raw, std::forward(args)...)); + return std::unique_ptr(raw, destroy_fn); +} + +using CudaStream = + std::unique_ptr, decltype(&cudaStreamDestroy)>; + +CudaStream CudaStreamCreate() { + return CreateWithCudaCheck(cudaStreamCreate, cudaStreamDestroy); +} + +template +auto CreateWithCusolverMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + if constexpr (raw_last) { + NVTE_CHECK_CUSOLVERMP(create_fn(std::forward(args)..., &raw)); + } else { + NVTE_CHECK_CUSOLVERMP(create_fn(&raw, std::forward(args)...)); + } + return std::unique_ptr(raw, destroy_fn); +} + +using CusolverMp = + std::unique_ptr, decltype(&cusolverMpDestroy)>; + +CusolverMp CusolverMpCreate(cudaStream_t stream) { + return CreateWithCusolverMpCheck(cusolverMpCreate, cusolverMpDestroy, + stream); +} + +using CusolverMpGrid = + std::unique_ptr, decltype(&cusolverMpDestroyGrid)>; + +CusolverMpGrid CusolverMpGridCreate(int64_t nprow, int64_t npcol, + cusolverMpGridLayout_t layout, ncclComm_t comm) { + return CreateWithCusolverMpCheck( + cusolverMpCreateDeviceGrid, cusolverMpDestroyGrid, nprow, npcol, layout, comm); +} + +using CusolverMpMatrixDesc = + std::unique_ptr, + decltype(&cusolverMpDestroyMatrixDesc)>; + +CusolverMpMatrixDesc CusolverMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb, + int64_t rsrc, int64_t csrc, int64_t lld, + cudaDataType_t type, cusolverMpGrid_t grid) { + return CreateWithCusolverMpCheck( + cusolverMpCreateMatrixDesc, cusolverMpDestroyMatrixDesc, m, n, mb, nb, rsrc, csrc, lld, type, + grid); +} + +using CusolverMpNSDesc = + std::unique_ptr, + decltype(&cusolverMpNewtonSchulzDescriptorDestroy)>; + +CusolverMpNSDesc CusolverMpNSDescCreate(int64_t num_iterations, const float* coefficients, + int64_t num_coefficients) { + return CreateWithCusolverMpCheck( + cusolverMpNewtonSchulzDescriptorCreate, cusolverMpNewtonSchulzDescriptorDestroy, + num_iterations, coefficients, num_coefficients); +} + +} // namespace + +struct NVTECusolverMpCtx { + int64_t nranks; + int64_t rank; + ncclComm_t comm; + CudaStream stream; + CusolverMp cusolver_mp; + CusolverMpGrid grid; + void* workspace; + size_t workspace_size; +}; + +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_API_CALL(nvte_cusolvermp_ctx_create); + auto stream = CudaStreamCreate(); + auto cusolver_mp = CusolverMpCreate(stream.get()); + + // 1D row partition: nranks x 1, column-major + auto grid = + CusolverMpGridCreate(nranks, 1, CUSOLVERMP_GRID_LAYOUT_COL_MAJOR, comm); + + return new NVTECusolverMpCtx{ + .nranks = nranks, + .rank = rank, + .comm = comm, + .stream = std::move(stream), + .cusolver_mp = std::move(cusolver_mp), + .grid = std::move(grid), + .workspace = nullptr, + .workspace_size = 0, + }; +} + +void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { + NVTE_API_CALL(nvte_cusolvermp_ctx_destroy); + if (ctx->workspace) { + NVTE_CHECK_CUDA(cudaFree(ctx->workspace)); + } + delete ctx; +} + +void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, + int64_t num_iterations, const float* coefficients, + int64_t num_coefficients, cudaStream_t stream) { + NVTE_API_CALL(nvte_newton_schulz); + const auto* t = convertNVTETensorCheck(x); + + // Block size for ScaLAPACK-style distribution + const int64_t mb = (m + ctx->nranks - 1) / ctx->nranks; + const int64_t nb = n; + + // Compute local leading dimension + const int64_t local_rows = cusolverMpNUMROC(m, mb, ctx->rank, 0, ctx->nranks); + const int64_t lld = std::max(local_rows, static_cast(1)); + + const cudaDataType_t cuda_dtype = get_cuda_dtype(t->dtype()); + + // Create matrix descriptor + auto mat_desc = CusolverMpMatrixDescCreate(m, n, mb, nb, 0, 0, lld, cuda_dtype, ctx->grid.get()); + + // Create Newton-Schulz descriptor + auto ns_desc = CusolverMpNSDescCreate(num_iterations, coefficients, num_coefficients); + + // Set stream on the cuSolverMp handle + NVTE_CHECK_CUSOLVERMP(cusolverMpStreamSet(ctx->cusolver_mp.get(), stream)); + + // Query workspace sizes + size_t wrksp_size_device = 0; + size_t wrksp_size_host = 0; + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz_bufferSize( + ctx->cusolver_mp.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), + &wrksp_size_device, &wrksp_size_host)); + + // Allocate/grow device workspace + if (ctx->workspace_size < wrksp_size_device) { + if (ctx->workspace) { + NVTE_CHECK_CUDA(cudaFree(ctx->workspace)); + } + NVTE_CHECK_CUDA(cudaMalloc(&ctx->workspace, wrksp_size_device)); + ctx->workspace_size = wrksp_size_device; + } + + // Allocate host workspace + std::vector workspace_host(wrksp_size_host); + + // Execute Newton-Schulz + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz( + ctx->cusolver_mp.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), + ctx->workspace, ctx->workspace_size, workspace_host.data(), workspace_host.size())); + + // Synchronize + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); +} diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index c542afa393..cbb55e5220 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -18,6 +18,10 @@ #include #endif // NVTE_WITH_CUBLASMP +#ifdef NVTE_WITH_CUSOLVERMP +#include +#endif // NVTE_WITH_CUSOLVERMP + #include #include #include @@ -106,6 +110,19 @@ #endif // NVTE_WITH_CUBLASMP +#ifdef NVTE_WITH_CUSOLVERMP + +#define NVTE_CHECK_CUSOLVERMP(expr) \ + do { \ + const cusolverStatus_t status_NVTE_CHECK_CUSOLVERMP = (expr); \ + if (status_NVTE_CHECK_CUSOLVERMP != CUSOLVER_STATUS_SUCCESS) { \ + NVTE_ERROR("cuSolverMp Error: ", \ + std::to_string(status_NVTE_CHECK_CUSOLVERMP)); \ + } \ + } while (false) + +#endif // NVTE_WITH_CUSOLVERMP + #define NVTE_CHECK_NCCL(expr) \ do { \ const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ From efefa7ec552bd105e717c9b5bc0bc2f2590071ba Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Sun, 8 Feb 2026 23:31:24 +0000 Subject: [PATCH 02/30] [PyTorch] Add Newton-Schulz PyTorch bindings and distributed tests Add PyTorch-level bindings for the cuSolverMp Newton-Schulz inverse square root API introduced in the previous commit. New files: - pytorch/csrc/extensions/newton_schulz.cpp: C++ extension wrapping the C API with PyTorch tensor support - pytorch/newton_schulz.py: Python wrapper that extracts NCCL communicator from torch.distributed ProcessGroup - tests/pytorch/distributed/test_newton_schulz.py: pytest launcher - tests/pytorch/distributed/run_newton_schulz.py: distributed test worker with reference implementation for numerical validation Modified files: - pytorch/csrc/extensions.h: Function declarations - pytorch/csrc/extensions/pybind.cpp: pybind11 registrations - pytorch/__init__.py: Public API export Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../pytorch/distributed/run_newton_schulz.py | 118 ++++++++++++++++++ .../pytorch/distributed/test_newton_schulz.py | 46 +++++++ transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/csrc/extensions.h | 11 ++ .../pytorch/csrc/extensions/newton_schulz.cpp | 55 ++++++++ .../pytorch/csrc/extensions/pybind.cpp | 12 ++ transformer_engine/pytorch/newton_schulz.py | 80 ++++++++++++ 7 files changed, 323 insertions(+) create mode 100644 tests/pytorch/distributed/run_newton_schulz.py create mode 100644 tests/pytorch/distributed/test_newton_schulz.py create mode 100644 transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp create mode 100644 transformer_engine/pytorch/newton_schulz.py diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py new file mode 100644 index 0000000000..bf0e211de6 --- /dev/null +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed Newton-Schulz test worker. + +Launched via torchrun from test_newton_schulz.py. +""" + +import argparse +import sys + +import torch +import torch.distributed as dist +from torch.distributed.elastic.multiprocessing.errors import record + + +def newton_schulz_reference(X, num_iterations, coefficients): + """Pure PyTorch reference Newton-Schulz inverse square root. + + Uses the polynomial iteration: X_{k+1} = sum_j coeff[j] * X_k^(2j+1) + for a quintic polynomial with 5 coefficients. + """ + for _ in range(num_iterations): + X2 = X @ X + # Quintic polynomial: c0*X + c1*X^3 + c2*X^5 + c3*X^7 + c4*X^9 + # = X * (c0 + X2 * (c1 + X2 * (c2 + X2 * (c3 + X2 * c4)))) + result = coefficients[4] + result = coefficients[3] + X2 * result + result = coefficients[2] + X2 * result + result = coefficients[1] + X2 * result + result = coefficients[0] + X2 * result + X = X @ result + return X + + +@record +def main(): + parser = argparse.ArgumentParser(description="Newton-Schulz distributed test") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"]) + parser.add_argument("--matrix-size", type=int, default=256) + parser.add_argument("--num-iterations", type=int, default=5) + parser.add_argument("--atol", type=float, default=1e-2) + parser.add_argument("--rtol", type=float, default=1e-2) + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16 + N = args.matrix_size + + # Ensure N is divisible by world_size + assert N % world_size == 0, f"Matrix size {N} must be divisible by world_size {world_size}" + + # Default quintic polynomial coefficients + coefficients = [ + 3069.0 / 1024.0, + -7175.0 / 1024.0, + 9009.0 / 1024.0, + -6435.0 / 1024.0, + 2835.0 / 2048.0, + ] + + # Create a random symmetric positive definite matrix on rank 0 + # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0, 1) + # This ensures Newton-Schulz converges + if rank == 0: + torch.manual_seed(42) + Q, _ = torch.linalg.qr(torch.randn(N, N, device="cuda", dtype=torch.float32)) + eigenvalues = torch.rand(N, device="cuda", dtype=torch.float32) * 0.8 + 0.1 + A = Q @ torch.diag(eigenvalues) @ Q.T + A = A.to(dtype) + else: + A = torch.empty(N, N, device="cuda", dtype=dtype) + + # Broadcast the full matrix to all ranks + dist.broadcast(A, src=0) + + # Compute reference on the full matrix + A_ref = A.clone() + result_ref = newton_schulz_reference(A_ref, args.num_iterations, coefficients) + + # Scatter rows to each rank + local_rows = N // world_size + x_local = A[rank * local_rows : (rank + 1) * local_rows, :].contiguous() + + # Run the distributed Newton-Schulz + from transformer_engine.pytorch.newton_schulz import newton_schulz + + group = dist.group.WORLD + newton_schulz(x_local, group, args.num_iterations, coefficients) + + # Gather results + gathered = [torch.empty_like(x_local) for _ in range(world_size)] + dist.all_gather(gathered, x_local) + result_distributed = torch.cat(gathered, dim=0) + + # Check numerical accuracy on rank 0 + if rank == 0: + max_diff = (result_distributed - result_ref).abs().max().item() + rel_diff = max_diff / (result_ref.abs().max().item() + 1e-12) + print(f"Max absolute diff: {max_diff:.6e}", flush=True) + print(f"Max relative diff: {rel_diff:.6e}", flush=True) + + if torch.allclose(result_distributed, result_ref, atol=args.atol, rtol=args.rtol): + print("NUMERICAL CHECK PASSED", flush=True) + else: + print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr) + sys.exit(1) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py new file mode 100644 index 0000000000..ccb2d3cf08 --- /dev/null +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for distributed Newton-Schulz inverse square root.""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +if torch.cuda.device_count() < 2: + pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True) + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS = min(torch.cuda.device_count(), 4) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) +@pytest.mark.parametrize("matrix_size", [256]) +def test_newton_schulz(dtype, matrix_size): + """Test distributed Newton-Schulz inverse square root.""" + test_path = TEST_ROOT / "run_newton_schulz.py" + test_cmd = LAUNCH_CMD + [ + str(test_path), + f"--dtype={dtype}", + f"--matrix-size={matrix_size}", + "--num-iterations=5", + ] + if dtype == "bfloat16": + test_cmd += ["--atol=5e-2", "--rtol=5e-2"] + + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) + if ( + result.returncode != 0 + or "NUMERICAL CHECK FAILED" in result.stderr.decode() + or "NUMERICAL CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError( + f"Newton-Schulz test failed.\n" + f"stdout: {result.stdout.decode()}\n" + f"stderr: {result.stderr.decode()}" + ) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5e1eb6954b..ec9caa245f 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -59,6 +59,7 @@ from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch.newton_schulz import newton_schulz from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.quantized_tensor import Quantizer diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f7cf32eaf6..0bb20ce1b2 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -504,6 +504,17 @@ void nvshmem_finalize(); void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream); +/*************************************************************************************************** + * Newton-Schulz (cuSolverMp) + **************************************************************************************************/ + +int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank); + +void cusolvermp_ctx_destroy(int64_t ctx_ptr); + +void newton_schulz(int64_t ctx_ptr, at::Tensor x, int64_t num_iterations, + std::vector coefficients); + } // namespace transformer_engine::pytorch /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp new file mode 100644 index 0000000000..7d1b44b892 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -0,0 +1,55 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../extensions.h" + +#ifdef NVTE_WITH_CUSOLVERMP +#include "transformer_engine/newton_schulz.h" +#endif + +namespace transformer_engine::pytorch { + +int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank) { +#ifdef NVTE_WITH_CUSOLVERMP + auto comm = reinterpret_cast(nccl_comm_ptr); + auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank); + return reinterpret_cast(ctx); +#else + NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); + return 0; +#endif +} + +void cusolvermp_ctx_destroy(int64_t ctx_ptr) { +#ifdef NVTE_WITH_CUSOLVERMP + auto* ctx = reinterpret_cast(ctx_ptr); + nvte_cusolvermp_ctx_destroy(ctx); +#else + NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); +#endif +} + +void newton_schulz(int64_t ctx_ptr, at::Tensor x, int64_t num_iterations, + std::vector coefficients) { +#ifdef NVTE_WITH_CUSOLVERMP + auto* ctx = reinterpret_cast(ctx_ptr); + + // Build NVTETensor from PyTorch tensor + auto x_sizes = x.sizes().vec(); + std::vector shape(x_sizes.begin(), x_sizes.end()); + + auto te_dtype = GetTransformerEngineDType(x.scalar_type()); + TensorWrapper x_tensor(x.data_ptr(), shape, te_dtype); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + nvte_newton_schulz(ctx, x.size(0), x.size(1), x_tensor.data(), num_iterations, + coefficients.data(), static_cast(coefficients.size()), stream); +#else + NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); +#endif +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 79dd9ea5ce..d48eccbb88 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -443,6 +443,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda, "Fused compute E8M0 scale_inv from amax", py::call_guard()); + // Newton-Schulz (cuSolverMp) + m.def("cusolvermp_ctx_create", &transformer_engine::pytorch::cusolvermp_ctx_create, + "Create cuSolverMp context for Newton-Schulz", py::arg("nccl_comm_ptr"), py::arg("nranks"), + py::arg("rank"), py::call_guard()); + m.def("cusolvermp_ctx_destroy", &transformer_engine::pytorch::cusolvermp_ctx_destroy, + "Destroy cuSolverMp context", py::arg("ctx_ptr"), + py::call_guard()); + m.def("newton_schulz", &transformer_engine::pytorch::newton_schulz, + "Newton-Schulz inverse square root", py::arg("ctx_ptr"), py::arg("x"), + py::arg("num_iterations"), py::arg("coefficients"), + py::call_guard()); + // Comm+GEMM Overlap m.def("bulk_overlap_ag_with_external_gemm", &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm, diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py new file mode 100644 index 0000000000..a8cf83f774 --- /dev/null +++ b/transformer_engine/pytorch/newton_schulz.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed Newton-Schulz inverse square root via cuSolverMp.""" + +from typing import List, Optional + +import torch +import torch.distributed as dist + +import transformer_engine_torch as tex + + +# Default quintic polynomial coefficients for 5-iteration Newton-Schulz +# from cuSolverMp sample: (3069/1024, -7175/1024, 9009/1024, -6435/1024, 2835/2048) +_DEFAULT_COEFFICIENTS = [ + 3069.0 / 1024.0, + -7175.0 / 1024.0, + 9009.0 / 1024.0, + -6435.0 / 1024.0, + 2835.0 / 2048.0, +] + + +def _get_nccl_comm_ptr(group: dist.ProcessGroup) -> int: + """Extract the raw NCCL communicator pointer from a PyTorch process group.""" + backend = dist.get_backend(group) + if backend != "nccl": + raise RuntimeError( + f"newton_schulz requires NCCL backend, got '{backend}'" + ) + # Access the NCCL communicator via the internal _get_backend method + nccl_backend = group._get_backend(torch.device("cuda")) + comm = nccl_backend.get_nccl_comm() + return comm + + +def newton_schulz( + x: torch.Tensor, + group: dist.ProcessGroup, + num_iterations: int = 5, + coefficients: Optional[List[float]] = None, +) -> None: + """Compute Newton-Schulz inverse square root in-place on a distributed matrix. + + Parameters + ---------- + x : torch.Tensor + Local part of the distributed matrix (modified in-place). + Must be a 2D CUDA tensor of type float32 or bfloat16. + group : torch.distributed.ProcessGroup + Process group with NCCL backend for distributed communication. + num_iterations : int, optional + Number of Newton-Schulz iterations. Default: 5. + coefficients : list of float, optional + Polynomial coefficients for the Newton-Schulz iteration. + Default: quintic polynomial coefficients from cuSolverMp sample. + """ + if coefficients is None: + coefficients = _DEFAULT_COEFFICIENTS + + if x.dim() != 2: + raise ValueError(f"Expected 2D tensor, got {x.dim()}D") + if not x.is_cuda: + raise ValueError("Input tensor must be on CUDA device") + + nccl_comm_ptr = _get_nccl_comm_ptr(group) + nranks = dist.get_world_size(group) + rank = dist.get_rank(group) + + # Global matrix dimensions + m = x.size(0) * nranks # rows are distributed across ranks + n = x.size(1) + + ctx_ptr = tex.cusolvermp_ctx_create(nccl_comm_ptr, nranks, rank) + try: + tex.newton_schulz(ctx_ptr, x, num_iterations, coefficients) + finally: + tex.cusolvermp_ctx_destroy(ctx_ptr) From 02299b31da310e8a7f3130dff3792e5b7d435328 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Sun, 8 Feb 2026 23:47:19 +0000 Subject: [PATCH 03/30] [Common] Fix cuSolverMp API signatures in Newton-Schulz implementation Fix API mismatches discovered during compilation: - cusolverMpCreate takes (handle*, deviceId, stream), not (handle*, stream) - cusolverMpCreateDeviceGrid takes handle as first arg with different parameter order - Use cusolverMpGridMapping_t (not cusolverMpGridLayout_t) and CUSOLVERMP_GRID_MAPPING_COL_MAJOR - cusolverMpCreateMatrixDesc has different parameter order: (desc*, grid, dtype, M, N, MB, NB, RSRC, CSRC, LLD) - cusolverMpNewtonSchulzDescriptorCreate takes only (nsDesc*) with no iteration/coefficient args - No cusolverMpStreamSet exists; create handle per-call with user stream - cusolverMpNewtonSchulz requires computeType and info parameters - Switch from generic template RAII to explicit deleter structs Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../common/newton_schulz/newton_schulz.cpp | 144 ++++++++---------- 1 file changed, 66 insertions(+), 78 deletions(-) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index f70986a122..2eb222363f 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -19,72 +19,61 @@ using namespace transformer_engine; namespace { -template -auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { - using Handle = std::remove_pointer_t; - HandlePtr raw{}; - NVTE_CHECK_CUDA(create_fn(&raw, std::forward(args)...)); - return std::unique_ptr(raw, destroy_fn); -} +// RAII wrapper types for cuSolverMp handles -using CudaStream = - std::unique_ptr, decltype(&cudaStreamDestroy)>; +struct CusolverMpHandleDeleter { + void operator()(cusolverMpHandle_t handle) const { cusolverMpDestroy(handle); } +}; +using CusolverMpHandle = std::unique_ptr, + CusolverMpHandleDeleter>; -CudaStream CudaStreamCreate() { - return CreateWithCudaCheck(cudaStreamCreate, cudaStreamDestroy); -} +struct CusolverMpGridDeleter { + void operator()(cusolverMpGrid_t grid) const { cusolverMpDestroyGrid(grid); } +}; +using CusolverMpGrid = std::unique_ptr, + CusolverMpGridDeleter>; -template -auto CreateWithCusolverMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { - using Handle = std::remove_pointer_t; - HandlePtr raw{}; - if constexpr (raw_last) { - NVTE_CHECK_CUSOLVERMP(create_fn(std::forward(args)..., &raw)); - } else { - NVTE_CHECK_CUSOLVERMP(create_fn(&raw, std::forward(args)...)); - } - return std::unique_ptr(raw, destroy_fn); -} +struct CusolverMpMatrixDescDeleter { + void operator()(cusolverMpMatrixDescriptor_t desc) const { cusolverMpDestroyMatrixDesc(desc); } +}; +using CusolverMpMatrixDesc = std::unique_ptr, + CusolverMpMatrixDescDeleter>; -using CusolverMp = - std::unique_ptr, decltype(&cusolverMpDestroy)>; +struct CusolverMpNSDescDeleter { + void operator()(cusolverMpNewtonSchulzDescriptor_t desc) const { + cusolverMpNewtonSchulzDescriptorDestroy(desc); + } +}; +using CusolverMpNSDesc = std::unique_ptr, + CusolverMpNSDescDeleter>; -CusolverMp CusolverMpCreate(cudaStream_t stream) { - return CreateWithCusolverMpCheck(cusolverMpCreate, cusolverMpDestroy, - stream); +CusolverMpHandle MakeCusolverMpHandle(int device_id, cudaStream_t stream) { + cusolverMpHandle_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpCreate(&raw, device_id, stream)); + return CusolverMpHandle(raw); } -using CusolverMpGrid = - std::unique_ptr, decltype(&cusolverMpDestroyGrid)>; - -CusolverMpGrid CusolverMpGridCreate(int64_t nprow, int64_t npcol, - cusolverMpGridLayout_t layout, ncclComm_t comm) { - return CreateWithCusolverMpCheck( - cusolverMpCreateDeviceGrid, cusolverMpDestroyGrid, nprow, npcol, layout, comm); +CusolverMpGrid MakeCusolverMpGrid(cusolverMpHandle_t handle, ncclComm_t comm, + int32_t nprow, int32_t npcol, + cusolverMpGridMapping_t mapping) { + cusolverMpGrid_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpCreateDeviceGrid(handle, &raw, comm, nprow, npcol, mapping)); + return CusolverMpGrid(raw); } -using CusolverMpMatrixDesc = - std::unique_ptr, - decltype(&cusolverMpDestroyMatrixDesc)>; - -CusolverMpMatrixDesc CusolverMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb, - int64_t rsrc, int64_t csrc, int64_t lld, - cudaDataType_t type, cusolverMpGrid_t grid) { - return CreateWithCusolverMpCheck( - cusolverMpCreateMatrixDesc, cusolverMpDestroyMatrixDesc, m, n, mb, nb, rsrc, csrc, lld, type, - grid); +CusolverMpMatrixDesc MakeCusolverMpMatrixDesc(cusolverMpGrid_t grid, cudaDataType_t dtype, + int64_t m, int64_t n, int64_t mb, int64_t nb, + uint32_t rsrc, uint32_t csrc, int64_t lld) { + cusolverMpMatrixDescriptor_t raw{}; + NVTE_CHECK_CUSOLVERMP( + cusolverMpCreateMatrixDesc(&raw, grid, dtype, m, n, mb, nb, rsrc, csrc, lld)); + return CusolverMpMatrixDesc(raw); } -using CusolverMpNSDesc = - std::unique_ptr, - decltype(&cusolverMpNewtonSchulzDescriptorDestroy)>; - -CusolverMpNSDesc CusolverMpNSDescCreate(int64_t num_iterations, const float* coefficients, - int64_t num_coefficients) { - return CreateWithCusolverMpCheck( - cusolverMpNewtonSchulzDescriptorCreate, cusolverMpNewtonSchulzDescriptorDestroy, - num_iterations, coefficients, num_coefficients); +CusolverMpNSDesc MakeCusolverMpNSDesc() { + cusolverMpNewtonSchulzDescriptor_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulzDescriptorCreate(&raw)); + return CusolverMpNSDesc(raw); } } // namespace @@ -93,29 +82,16 @@ struct NVTECusolverMpCtx { int64_t nranks; int64_t rank; ncclComm_t comm; - CudaStream stream; - CusolverMp cusolver_mp; - CusolverMpGrid grid; void* workspace; size_t workspace_size; }; NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank) { NVTE_API_CALL(nvte_cusolvermp_ctx_create); - auto stream = CudaStreamCreate(); - auto cusolver_mp = CusolverMpCreate(stream.get()); - - // 1D row partition: nranks x 1, column-major - auto grid = - CusolverMpGridCreate(nranks, 1, CUSOLVERMP_GRID_LAYOUT_COL_MAJOR, comm); - return new NVTECusolverMpCtx{ .nranks = nranks, .rank = rank, .comm = comm, - .stream = std::move(stream), - .cusolver_mp = std::move(cusolver_mp), - .grid = std::move(grid), .workspace = nullptr, .workspace_size = 0, }; @@ -124,7 +100,7 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { NVTE_API_CALL(nvte_cusolvermp_ctx_destroy); if (ctx->workspace) { - NVTE_CHECK_CUDA(cudaFree(ctx->workspace)); + cudaFree(ctx->workspace); } delete ctx; } @@ -135,6 +111,17 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor NVTE_API_CALL(nvte_newton_schulz); const auto* t = convertNVTETensorCheck(x); + // Get current device + int device_id{}; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + + // Create cuSolverMp handle bound to the caller's stream + auto handle = MakeCusolverMpHandle(device_id, stream); + + // 1D row partition: nranks x 1, column-major + auto grid = MakeCusolverMpGrid(handle.get(), ctx->comm, ctx->nranks, 1, + CUSOLVERMP_GRID_MAPPING_COL_MAJOR); + // Block size for ScaLAPACK-style distribution const int64_t mb = (m + ctx->nranks - 1) / ctx->nranks; const int64_t nb = n; @@ -146,20 +133,17 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor const cudaDataType_t cuda_dtype = get_cuda_dtype(t->dtype()); // Create matrix descriptor - auto mat_desc = CusolverMpMatrixDescCreate(m, n, mb, nb, 0, 0, lld, cuda_dtype, ctx->grid.get()); + auto mat_desc = MakeCusolverMpMatrixDesc(grid.get(), cuda_dtype, m, n, mb, nb, 0, 0, lld); // Create Newton-Schulz descriptor - auto ns_desc = CusolverMpNSDescCreate(num_iterations, coefficients, num_coefficients); - - // Set stream on the cuSolverMp handle - NVTE_CHECK_CUSOLVERMP(cusolverMpStreamSet(ctx->cusolver_mp.get(), stream)); + auto ns_desc = MakeCusolverMpNSDesc(); // Query workspace sizes size_t wrksp_size_device = 0; size_t wrksp_size_host = 0; NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz_bufferSize( - ctx->cusolver_mp.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), - &wrksp_size_device, &wrksp_size_host)); + handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + coefficients, CUDA_R_32F, &wrksp_size_device, &wrksp_size_host)); // Allocate/grow device workspace if (ctx->workspace_size < wrksp_size_device) { @@ -174,10 +158,14 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor std::vector workspace_host(wrksp_size_host); // Execute Newton-Schulz + int info = 0; NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz( - ctx->cusolver_mp.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), - ctx->workspace, ctx->workspace_size, workspace_host.data(), workspace_host.size())); + handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), + workspace_host.size(), &info)); // Synchronize NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + NVTE_CHECK(info == 0, "cusolverMpNewtonSchulz failed with info = ", info); } From ed6c21f90a2a08fc2889d9b61b829bf485f754c7 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Sun, 8 Feb 2026 23:47:25 +0000 Subject: [PATCH 04/30] [PyTorch] Propagate NVTE_WITH_CUSOLVERMP define to PyTorch extension build Add NVTE_WITH_CUSOLVERMP compiler define and cusolverMp include/library paths to the PyTorch C++ extension build, following the same pattern as NVTE_UB_WITH_MPI and NVTE_ENABLE_NVSHMEM. Without this, the #ifdef NVTE_WITH_CUSOLVERMP guards in the PyTorch extension code would never be active since the define was only set as PRIVATE in the CMake build for the common library. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- build_tools/pytorch.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index fdfdee9b1c..4530f1c4d7 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -87,6 +87,13 @@ def setup_pytorch_extension( libraries.append("nvshmem_host") cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))): + cusolvermp_home = Path(os.getenv("CUSOLVERMP_HOME", "/usr")) + include_dirs.append(cusolvermp_home / "include") + library_dirs.append(cusolvermp_home / "lib") + libraries.append("cusolverMp") + cxx_flags.append("-DNVTE_WITH_CUSOLVERMP") + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] From fbc1c4e807ab4c05be5ce3eb2b9e2ff07762085d Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 9 Feb 2026 00:06:39 +0000 Subject: [PATCH 05/30] [PyTorch] Fix NCCL comm extraction and pass global dims to Newton-Schulz Two fixes: - Use ProcessGroupNCCL._comm_ptr() to extract the raw NCCL communicator pointer instead of the non-existent get_nccl_comm() method - Pass global matrix dimensions (m, n) from Python to C++ instead of using local tensor dimensions, which would produce incorrect ScaLAPACK block sizes in the distributed computation Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- transformer_engine/pytorch/csrc/extensions.h | 4 ++-- .../pytorch/csrc/extensions/newton_schulz.cpp | 8 ++++---- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 4 ++-- transformer_engine/pytorch/newton_schulz.py | 6 ++---- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0bb20ce1b2..0e9d2c9318 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -512,8 +512,8 @@ int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank); void cusolvermp_ctx_destroy(int64_t ctx_ptr); -void newton_schulz(int64_t ctx_ptr, at::Tensor x, int64_t num_iterations, - std::vector coefficients); +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, + int64_t num_iterations, std::vector coefficients); } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 7d1b44b892..06324dabf9 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -32,8 +32,8 @@ void cusolvermp_ctx_destroy(int64_t ctx_ptr) { #endif } -void newton_schulz(int64_t ctx_ptr, at::Tensor x, int64_t num_iterations, - std::vector coefficients) { +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, + int64_t num_iterations, std::vector coefficients) { #ifdef NVTE_WITH_CUSOLVERMP auto* ctx = reinterpret_cast(ctx_ptr); @@ -45,8 +45,8 @@ void newton_schulz(int64_t ctx_ptr, at::Tensor x, int64_t num_iterations, TensorWrapper x_tensor(x.data_ptr(), shape, te_dtype); auto stream = at::cuda::getCurrentCUDAStream().stream(); - nvte_newton_schulz(ctx, x.size(0), x.size(1), x_tensor.data(), num_iterations, - coefficients.data(), static_cast(coefficients.size()), stream); + nvte_newton_schulz(ctx, m, n, x_tensor.data(), num_iterations, coefficients.data(), + static_cast(coefficients.size()), stream); #else NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); #endif diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d48eccbb88..d393ed0d4a 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -451,8 +451,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Destroy cuSolverMp context", py::arg("ctx_ptr"), py::call_guard()); m.def("newton_schulz", &transformer_engine::pytorch::newton_schulz, - "Newton-Schulz inverse square root", py::arg("ctx_ptr"), py::arg("x"), - py::arg("num_iterations"), py::arg("coefficients"), + "Newton-Schulz inverse square root", py::arg("ctx_ptr"), py::arg("m"), py::arg("n"), + py::arg("x"), py::arg("num_iterations"), py::arg("coefficients"), py::call_guard()); // Comm+GEMM Overlap diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py index a8cf83f774..f338003168 100644 --- a/transformer_engine/pytorch/newton_schulz.py +++ b/transformer_engine/pytorch/newton_schulz.py @@ -30,10 +30,8 @@ def _get_nccl_comm_ptr(group: dist.ProcessGroup) -> int: raise RuntimeError( f"newton_schulz requires NCCL backend, got '{backend}'" ) - # Access the NCCL communicator via the internal _get_backend method nccl_backend = group._get_backend(torch.device("cuda")) - comm = nccl_backend.get_nccl_comm() - return comm + return nccl_backend._comm_ptr() def newton_schulz( @@ -75,6 +73,6 @@ def newton_schulz( ctx_ptr = tex.cusolvermp_ctx_create(nccl_comm_ptr, nranks, rank) try: - tex.newton_schulz(ctx_ptr, x, num_iterations, coefficients) + tex.newton_schulz(ctx_ptr, m, n, x, num_iterations, coefficients) finally: tex.cusolvermp_ctx_destroy(ctx_ptr) From f8b23cc759cc76ba73676730a2fbc43fe651c0bc Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 01:56:01 +0000 Subject: [PATCH 06/30] [Common] Cache cuSolverMp handle and grid in Newton-Schulz context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cuSolverMp handle and grid creation are expensive operations. Move them from per-call creation in nvte_newton_schulz into the NVTECusolverMpCtx, which is their natural home — the context exists to encapsulate the grid. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../transformer_engine/newton_schulz.h | 4 +- .../common/newton_schulz/newton_schulz.cpp | 40 +++++++++---------- .../pytorch/csrc/extensions/newton_schulz.cpp | 3 +- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h index b540b9db54..bd0e2a815e 100644 --- a/transformer_engine/common/include/transformer_engine/newton_schulz.h +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -33,8 +33,10 @@ typedef struct NVTECusolverMpCtx NVTECusolverMpCtx; * \param[in] comm NCCL communicator. * \param[in] nranks Number of ranks. * \param[in] rank Local rank. + * \param[in] stream CUDA stream. */ -NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank); +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank, + cudaStream_t stream); /*! \brief Destroy a cuSolverMp context. * diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index 2eb222363f..a02e01372a 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -17,9 +17,8 @@ using namespace transformer_engine; -namespace { - -// RAII wrapper types for cuSolverMp handles +// RAII wrapper types for cuSolverMp handles (outside anonymous namespace because +// CusolverMpHandle and CusolverMpGrid are used in the NVTECusolverMpCtx struct) struct CusolverMpHandleDeleter { void operator()(cusolverMpHandle_t handle) const { cusolverMpDestroy(handle); } @@ -33,6 +32,8 @@ struct CusolverMpGridDeleter { using CusolverMpGrid = std::unique_ptr, CusolverMpGridDeleter>; +namespace { + struct CusolverMpMatrixDescDeleter { void operator()(cusolverMpMatrixDescriptor_t desc) const { cusolverMpDestroyMatrixDesc(desc); } }; @@ -81,17 +82,27 @@ CusolverMpNSDesc MakeCusolverMpNSDesc() { struct NVTECusolverMpCtx { int64_t nranks; int64_t rank; - ncclComm_t comm; + CusolverMpHandle handle; + CusolverMpGrid grid; void* workspace; size_t workspace_size; }; -NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank) { +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank, + cudaStream_t stream) { NVTE_API_CALL(nvte_cusolvermp_ctx_create); + int device_id{}; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + + auto handle = MakeCusolverMpHandle(device_id, stream); + auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, + CUSOLVERMP_GRID_MAPPING_COL_MAJOR); + return new NVTECusolverMpCtx{ .nranks = nranks, .rank = rank, - .comm = comm, + .handle = std::move(handle), + .grid = std::move(grid), .workspace = nullptr, .workspace_size = 0, }; @@ -111,17 +122,6 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor NVTE_API_CALL(nvte_newton_schulz); const auto* t = convertNVTETensorCheck(x); - // Get current device - int device_id{}; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - - // Create cuSolverMp handle bound to the caller's stream - auto handle = MakeCusolverMpHandle(device_id, stream); - - // 1D row partition: nranks x 1, column-major - auto grid = MakeCusolverMpGrid(handle.get(), ctx->comm, ctx->nranks, 1, - CUSOLVERMP_GRID_MAPPING_COL_MAJOR); - // Block size for ScaLAPACK-style distribution const int64_t mb = (m + ctx->nranks - 1) / ctx->nranks; const int64_t nb = n; @@ -133,7 +133,7 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor const cudaDataType_t cuda_dtype = get_cuda_dtype(t->dtype()); // Create matrix descriptor - auto mat_desc = MakeCusolverMpMatrixDesc(grid.get(), cuda_dtype, m, n, mb, nb, 0, 0, lld); + auto mat_desc = MakeCusolverMpMatrixDesc(ctx->grid.get(), cuda_dtype, m, n, mb, nb, 0, 0, lld); // Create Newton-Schulz descriptor auto ns_desc = MakeCusolverMpNSDesc(); @@ -142,7 +142,7 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor size_t wrksp_size_device = 0; size_t wrksp_size_host = 0; NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz_bufferSize( - handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + ctx->handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, coefficients, CUDA_R_32F, &wrksp_size_device, &wrksp_size_host)); // Allocate/grow device workspace @@ -160,7 +160,7 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor // Execute Newton-Schulz int info = 0; NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz( - handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + ctx->handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), workspace_host.size(), &info)); diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 06324dabf9..07a53f0c4e 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -15,7 +15,8 @@ namespace transformer_engine::pytorch { int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank) { #ifdef NVTE_WITH_CUSOLVERMP auto comm = reinterpret_cast(nccl_comm_ptr); - auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank, stream); return reinterpret_cast(ctx); #else NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); From 8dbdcbbd858dbf59e0482663cc1663b90b8d7d48 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 04:06:21 +0000 Subject: [PATCH 07/30] [Common] Create dedicated CUDA stream in Newton-Schulz context cuSolverMp cannot work with the default CUDA stream. Create a dedicated stream inside nvte_cusolvermp_ctx_create and remove the stream parameter from both C API functions since the context now owns its stream. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../include/transformer_engine/newton_schulz.h | 10 +++++----- .../common/newton_schulz/newton_schulz.cpp | 16 ++++++++++++---- .../pytorch/csrc/extensions/newton_schulz.cpp | 6 ++---- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h index bd0e2a815e..1336015270 100644 --- a/transformer_engine/common/include/transformer_engine/newton_schulz.h +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -29,14 +29,15 @@ extern "C" { typedef struct NVTECusolverMpCtx NVTECusolverMpCtx; /*! \brief Create a cuSolverMp context for Newton-Schulz operations. + * + * Creates a dedicated CUDA stream internally (cuSolverMp requires a + * non-default stream). * * \param[in] comm NCCL communicator. * \param[in] nranks Number of ranks. * \param[in] rank Local rank. - * \param[in] stream CUDA stream. */ -NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank, - cudaStream_t stream); +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank); /*! \brief Destroy a cuSolverMp context. * @@ -57,11 +58,10 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); * \param[in] coefficients Array of polynomial coefficients (length depends on polynomial * degree used internally by cuSolverMp). * \param[in] num_coefficients Number of elements in the coefficients array. - * \param[in] stream CUDA stream. */ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, int64_t num_iterations, const float* coefficients, - int64_t num_coefficients, cudaStream_t stream); + int64_t num_coefficients); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index a02e01372a..394e09a6b3 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -82,18 +82,21 @@ CusolverMpNSDesc MakeCusolverMpNSDesc() { struct NVTECusolverMpCtx { int64_t nranks; int64_t rank; + cudaStream_t stream; CusolverMpHandle handle; CusolverMpGrid grid; void* workspace; size_t workspace_size; }; -NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank, - cudaStream_t stream) { +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank) { NVTE_API_CALL(nvte_cusolvermp_ctx_create); int device_id{}; NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + cudaStream_t stream{}; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + auto handle = MakeCusolverMpHandle(device_id, stream); auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, CUSOLVERMP_GRID_MAPPING_COL_MAJOR); @@ -101,6 +104,7 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r return new NVTECusolverMpCtx{ .nranks = nranks, .rank = rank, + .stream = stream, .handle = std::move(handle), .grid = std::move(grid), .workspace = nullptr, @@ -113,12 +117,16 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { if (ctx->workspace) { cudaFree(ctx->workspace); } + // Destroy handle and grid before the stream they depend on + ctx->handle.reset(); + ctx->grid.reset(); + cudaStreamDestroy(ctx->stream); delete ctx; } void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, int64_t num_iterations, const float* coefficients, - int64_t num_coefficients, cudaStream_t stream) { + int64_t num_coefficients) { NVTE_API_CALL(nvte_newton_schulz); const auto* t = convertNVTETensorCheck(x); @@ -165,7 +173,7 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor workspace_host.size(), &info)); // Synchronize - NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(ctx->stream)); NVTE_CHECK(info == 0, "cusolverMpNewtonSchulz failed with info = ", info); } diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 07a53f0c4e..9892fa1802 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -15,8 +15,7 @@ namespace transformer_engine::pytorch { int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank) { #ifdef NVTE_WITH_CUSOLVERMP auto comm = reinterpret_cast(nccl_comm_ptr); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank, stream); + auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank); return reinterpret_cast(ctx); #else NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); @@ -45,9 +44,8 @@ void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, auto te_dtype = GetTransformerEngineDType(x.scalar_type()); TensorWrapper x_tensor(x.data_ptr(), shape, te_dtype); - auto stream = at::cuda::getCurrentCUDAStream().stream(); nvte_newton_schulz(ctx, m, n, x_tensor.data(), num_iterations, coefficients.data(), - static_cast(coefficients.size()), stream); + static_cast(coefficients.size())); #else NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); #endif From 1c01a9d5dc6db06cb36e5bec1451026d901b2692 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 06:10:49 +0000 Subject: [PATCH 08/30] [Common] Fix Newton-Schulz zero output with event-based stream sync The internal dedicated stream was reading the input tensor before the caller's stream had finished producing it, resulting in all-zero output. Add event-based synchronisation: the internal stream waits for the caller's input to be ready, and the caller's stream waits for the output to be written. Replaces the blocking cudaStreamSynchronize. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../transformer_engine/newton_schulz.h | 4 +++- .../common/newton_schulz/newton_schulz.cpp | 19 ++++++++++++++++--- .../pytorch/csrc/extensions/newton_schulz.cpp | 3 ++- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h index 1336015270..543c3f63ba 100644 --- a/transformer_engine/common/include/transformer_engine/newton_schulz.h +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -58,10 +58,12 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); * \param[in] coefficients Array of polynomial coefficients (length depends on polynomial * degree used internally by cuSolverMp). * \param[in] num_coefficients Number of elements in the coefficients array. + * \param[in] caller_stream CUDA stream on which the caller produced the input tensor. + * Used for event-based synchronisation with the internal stream. */ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, int64_t num_iterations, const float* coefficients, - int64_t num_coefficients); + int64_t num_coefficients, cudaStream_t caller_stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index 394e09a6b3..79939405b8 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -126,10 +126,18 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, int64_t num_iterations, const float* coefficients, - int64_t num_coefficients) { + int64_t num_coefficients, cudaStream_t caller_stream) { NVTE_API_CALL(nvte_newton_schulz); const auto* t = convertNVTETensorCheck(x); + // Make the internal stream wait for the caller's stream so that + // the input tensor is ready before cuSolverMp reads it. + cudaEvent_t input_ready{}; + NVTE_CHECK_CUDA(cudaEventCreate(&input_ready)); + NVTE_CHECK_CUDA(cudaEventRecord(input_ready, caller_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream, input_ready)); + NVTE_CHECK_CUDA(cudaEventDestroy(input_ready)); + // Block size for ScaLAPACK-style distribution const int64_t mb = (m + ctx->nranks - 1) / ctx->nranks; const int64_t nb = n; @@ -172,8 +180,13 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), workspace_host.size(), &info)); - // Synchronize - NVTE_CHECK_CUDA(cudaStreamSynchronize(ctx->stream)); + // Make the caller's stream wait for the internal stream so that + // the in-place result is visible to subsequent work on caller_stream. + cudaEvent_t output_ready{}; + NVTE_CHECK_CUDA(cudaEventCreate(&output_ready)); + NVTE_CHECK_CUDA(cudaEventRecord(output_ready, ctx->stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(caller_stream, output_ready)); + NVTE_CHECK_CUDA(cudaEventDestroy(output_ready)); NVTE_CHECK(info == 0, "cusolverMpNewtonSchulz failed with info = ", info); } diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 9892fa1802..34d9bb87f8 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -44,8 +44,9 @@ void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, auto te_dtype = GetTransformerEngineDType(x.scalar_type()); TensorWrapper x_tensor(x.data_ptr(), shape, te_dtype); + auto caller_stream = at::cuda::getCurrentCUDAStream().stream(); nvte_newton_schulz(ctx, m, n, x_tensor.data(), num_iterations, coefficients.data(), - static_cast(coefficients.size())); + static_cast(coefficients.size()), caller_stream); #else NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); #endif From 1d8115ddd45c098c15114e47c46c4ce3c2c280b6 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 06:22:40 +0000 Subject: [PATCH 09/30] [Common] Fix Newton-Schulz NaNs by keeping host workspace alive cuSolverMp is asynchronous and uses the host workspace during multi-GPU execution. The event-based output sync did not block the host, so the local workspace_host vector was destroyed while the GPU was still reading from it. Restore cudaStreamSynchronize to ensure the host workspace remains valid for the full duration of the operation. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../common/newton_schulz/newton_schulz.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index 79939405b8..b39eff3c2e 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -180,13 +180,11 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), workspace_host.size(), &info)); - // Make the caller's stream wait for the internal stream so that - // the in-place result is visible to subsequent work on caller_stream. - cudaEvent_t output_ready{}; - NVTE_CHECK_CUDA(cudaEventCreate(&output_ready)); - NVTE_CHECK_CUDA(cudaEventRecord(output_ready, ctx->stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(caller_stream, output_ready)); - NVTE_CHECK_CUDA(cudaEventDestroy(output_ready)); + // Host-side sync: cuSolverMp is asynchronous and uses the host + // workspace during execution for multi-GPU coordination. We must + // block until the stream finishes so that workspace_host (a local + // vector) is not destroyed while the GPU is still reading from it. + NVTE_CHECK_CUDA(cudaStreamSynchronize(ctx->stream)); NVTE_CHECK(info == 0, "cusolverMpNewtonSchulz failed with info = ", info); } From fcb4d33dad72320d465bf6ba2d6b2b8223398ed5 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 19:57:41 +0000 Subject: [PATCH 10/30] [Common] Cache CUDA event in Newton-Schulz context Avoid creating and destroying a cudaEvent_t on every nvte_newton_schulz call by making it a persistent member of NVTECusolverMpCtx, matching the existing pattern for the stream. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../common/newton_schulz/newton_schulz.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index b39eff3c2e..90c553f066 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -83,6 +83,7 @@ struct NVTECusolverMpCtx { int64_t nranks; int64_t rank; cudaStream_t stream; + cudaEvent_t event; CusolverMpHandle handle; CusolverMpGrid grid; void* workspace; @@ -97,6 +98,9 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r cudaStream_t stream{}; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + cudaEvent_t event{}; + NVTE_CHECK_CUDA(cudaEventCreate(&event)); + auto handle = MakeCusolverMpHandle(device_id, stream); auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, CUSOLVERMP_GRID_MAPPING_COL_MAJOR); @@ -105,6 +109,7 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r .nranks = nranks, .rank = rank, .stream = stream, + .event = event, .handle = std::move(handle), .grid = std::move(grid), .workspace = nullptr, @@ -120,6 +125,7 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { // Destroy handle and grid before the stream they depend on ctx->handle.reset(); ctx->grid.reset(); + cudaEventDestroy(ctx->event); cudaStreamDestroy(ctx->stream); delete ctx; } @@ -132,11 +138,8 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor // Make the internal stream wait for the caller's stream so that // the input tensor is ready before cuSolverMp reads it. - cudaEvent_t input_ready{}; - NVTE_CHECK_CUDA(cudaEventCreate(&input_ready)); - NVTE_CHECK_CUDA(cudaEventRecord(input_ready, caller_stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream, input_ready)); - NVTE_CHECK_CUDA(cudaEventDestroy(input_ready)); + NVTE_CHECK_CUDA(cudaEventRecord(ctx->event, caller_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream, ctx->event)); // Block size for ScaLAPACK-style distribution const int64_t mb = (m + ctx->nranks - 1) / ctx->nranks; From b4422b9103faa1439945228918a067357e5ec6f4 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 21:29:36 +0000 Subject: [PATCH 11/30] [Common] Use separate in/out events for Newton-Schulz stream sync Replace single event with in_ready and out_ready events. After the cuSolverMp call, record out_ready on the internal stream and make the caller's stream wait on it, ensuring the output tensor is ready before the caller uses it. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../common/newton_schulz/newton_schulz.cpp | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index 90c553f066..a4a6eb7ae4 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -83,7 +83,8 @@ struct NVTECusolverMpCtx { int64_t nranks; int64_t rank; cudaStream_t stream; - cudaEvent_t event; + cudaEvent_t in_ready; + cudaEvent_t out_ready; CusolverMpHandle handle; CusolverMpGrid grid; void* workspace; @@ -98,8 +99,10 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r cudaStream_t stream{}; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - cudaEvent_t event{}; - NVTE_CHECK_CUDA(cudaEventCreate(&event)); + cudaEvent_t in_ready{}; + NVTE_CHECK_CUDA(cudaEventCreate(&in_ready)); + cudaEvent_t out_ready{}; + NVTE_CHECK_CUDA(cudaEventCreate(&out_ready)); auto handle = MakeCusolverMpHandle(device_id, stream); auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, @@ -109,7 +112,8 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r .nranks = nranks, .rank = rank, .stream = stream, - .event = event, + .in_ready = in_ready, + .out_ready = out_ready, .handle = std::move(handle), .grid = std::move(grid), .workspace = nullptr, @@ -125,7 +129,8 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { // Destroy handle and grid before the stream they depend on ctx->handle.reset(); ctx->grid.reset(); - cudaEventDestroy(ctx->event); + cudaEventDestroy(ctx->in_ready); + cudaEventDestroy(ctx->out_ready); cudaStreamDestroy(ctx->stream); delete ctx; } @@ -138,8 +143,8 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor // Make the internal stream wait for the caller's stream so that // the input tensor is ready before cuSolverMp reads it. - NVTE_CHECK_CUDA(cudaEventRecord(ctx->event, caller_stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream, ctx->event)); + NVTE_CHECK_CUDA(cudaEventRecord(ctx->in_ready, caller_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream, ctx->in_ready)); // Block size for ScaLAPACK-style distribution const int64_t mb = (m + ctx->nranks - 1) / ctx->nranks; @@ -183,6 +188,11 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), workspace_host.size(), &info)); + // Make the caller's stream wait for the internal stream so that + // the output tensor is ready before the caller uses it. + NVTE_CHECK_CUDA(cudaEventRecord(ctx->out_ready, ctx->stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(caller_stream, ctx->out_ready)); + // Host-side sync: cuSolverMp is asynchronous and uses the host // workspace during execution for multi-GPU coordination. We must // block until the stream finishes so that workspace_host (a local From e8c51f80f396ff8971536b484d39cbb8d7537c6f Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 22:43:55 +0000 Subject: [PATCH 12/30] Correct coefficients Signed-off-by: Vladimir Cherepanov --- tests/pytorch/distributed/run_newton_schulz.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index bf0e211de6..d6b537b7b7 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -55,14 +55,11 @@ def main(): # Ensure N is divisible by world_size assert N % world_size == 0, f"Matrix size {N} must be divisible by world_size {world_size}" - # Default quintic polynomial coefficients - coefficients = [ - 3069.0 / 1024.0, - -7175.0 / 1024.0, - 9009.0 / 1024.0, - -6435.0 / 1024.0, - 2835.0 / 2048.0, - ] + coefficients = [4.0848, -6.8946, 2.9270, + 3.9505, -6.3029, 2.6377, + 3.7418, -5.5913, 2.3037, + 2.8769, -3.1427, 1.2046, + 2.8366, -3.0525, 1.2012] # Create a random symmetric positive definite matrix on rank 0 # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0, 1) From 412445cdc986929671427f543472b3d0b6c8c4b6 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 22:49:49 +0000 Subject: [PATCH 13/30] No stream synchronize Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/newton_schulz/newton_schulz.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index a4a6eb7ae4..fa1be76b20 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -193,11 +193,5 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor NVTE_CHECK_CUDA(cudaEventRecord(ctx->out_ready, ctx->stream)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(caller_stream, ctx->out_ready)); - // Host-side sync: cuSolverMp is asynchronous and uses the host - // workspace during execution for multi-GPU coordination. We must - // block until the stream finishes so that workspace_host (a local - // vector) is not destroyed while the GPU is still reading from it. - NVTE_CHECK_CUDA(cudaStreamSynchronize(ctx->stream)); - NVTE_CHECK(info == 0, "cusolverMpNewtonSchulz failed with info = ", info); } From 9645073a8d6c8855a0fac773b031c71a7642347c Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 18 Feb 2026 23:08:33 +0000 Subject: [PATCH 14/30] [Test] Verify Newton-Schulz result with XAX=I identity check Replace reference-comparison test with a direct arithmetic check: if X is the inverse square root of A, then X @ A @ X must equal the identity matrix. This is more robust and removes the need for a separate reference implementation. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- .../pytorch/distributed/run_newton_schulz.py | 38 +++++-------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index d6b537b7b7..15834f2830 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -15,25 +15,6 @@ from torch.distributed.elastic.multiprocessing.errors import record -def newton_schulz_reference(X, num_iterations, coefficients): - """Pure PyTorch reference Newton-Schulz inverse square root. - - Uses the polynomial iteration: X_{k+1} = sum_j coeff[j] * X_k^(2j+1) - for a quintic polynomial with 5 coefficients. - """ - for _ in range(num_iterations): - X2 = X @ X - # Quintic polynomial: c0*X + c1*X^3 + c2*X^5 + c3*X^7 + c4*X^9 - # = X * (c0 + X2 * (c1 + X2 * (c2 + X2 * (c3 + X2 * c4)))) - result = coefficients[4] - result = coefficients[3] + X2 * result - result = coefficients[2] + X2 * result - result = coefficients[1] + X2 * result - result = coefficients[0] + X2 * result - X = X @ result - return X - - @record def main(): parser = argparse.ArgumentParser(description="Newton-Schulz distributed test") @@ -76,9 +57,8 @@ def main(): # Broadcast the full matrix to all ranks dist.broadcast(A, src=0) - # Compute reference on the full matrix - A_ref = A.clone() - result_ref = newton_schulz_reference(A_ref, args.num_iterations, coefficients) + # Keep a copy of the original matrix for verification + A_orig = A.clone() # Scatter rows to each rank local_rows = N // world_size @@ -93,16 +73,16 @@ def main(): # Gather results gathered = [torch.empty_like(x_local) for _ in range(world_size)] dist.all_gather(gathered, x_local) - result_distributed = torch.cat(gathered, dim=0) + X = torch.cat(gathered, dim=0) - # Check numerical accuracy on rank 0 + # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix if rank == 0: - max_diff = (result_distributed - result_ref).abs().max().item() - rel_diff = max_diff / (result_ref.abs().max().item() + 1e-12) - print(f"Max absolute diff: {max_diff:.6e}", flush=True) - print(f"Max relative diff: {rel_diff:.6e}", flush=True) + XAX = X @ A_orig @ X + I = torch.eye(N, device=XAX.device, dtype=XAX.dtype) + max_diff = (XAX - I).abs().max().item() + print(f"Max |X @ A @ X - I|: {max_diff:.6e}", flush=True) - if torch.allclose(result_distributed, result_ref, atol=args.atol, rtol=args.rtol): + if torch.allclose(XAX, I, atol=args.atol, rtol=args.rtol): print("NUMERICAL CHECK PASSED", flush=True) else: print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr) From dd1dd0b43ed38834ff0879938b46f28ca66a155a Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 19 Feb 2026 21:36:56 +0000 Subject: [PATCH 15/30] Change test - it approximates orthogonal matrix, not inverse square root Signed-off-by: Vladimir Cherepanov --- tests/pytorch/distributed/run_newton_schulz.py | 10 +++++----- tests/pytorch/distributed/test_newton_schulz.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index 15834f2830..7ff57cb394 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -77,12 +77,12 @@ def main(): # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix if rank == 0: - XAX = X @ A_orig @ X - I = torch.eye(N, device=XAX.device, dtype=XAX.dtype) - max_diff = (XAX - I).abs().max().item() - print(f"Max |X @ A @ X - I|: {max_diff:.6e}", flush=True) + XXT = X @ X.t() + I = torch.eye(N, device=XXT.device, dtype=XXT.dtype) + max_diff = (XXT - I).abs().max().item() + print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True) - if torch.allclose(XAX, I, atol=args.atol, rtol=args.rtol): + if torch.allclose(XXT, I, atol=args.atol, rtol=args.rtol): print("NUMERICAL CHECK PASSED", flush=True) else: print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr) diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py index ccb2d3cf08..fd253036cf 100644 --- a/tests/pytorch/distributed/test_newton_schulz.py +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -15,7 +15,7 @@ pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True) TEST_ROOT = Path(__file__).parent.resolve() -NUM_PROCS = min(torch.cuda.device_count(), 4) +NUM_PROCS = torch.cuda.device_count() LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] From 85d33fb1f509fd55ff61cc3ebf821998b670f7e1 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 19 Feb 2026 22:12:20 +0000 Subject: [PATCH 16/30] Generalize number of iterations in tests Signed-off-by: Vladimir Cherepanov --- tests/pytorch/distributed/run_newton_schulz.py | 11 ++++++----- tests/pytorch/distributed/test_newton_schulz.py | 5 +++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index 7ff57cb394..223aa23602 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -36,11 +36,12 @@ def main(): # Ensure N is divisible by world_size assert N % world_size == 0, f"Matrix size {N} must be divisible by world_size {world_size}" - coefficients = [4.0848, -6.8946, 2.9270, - 3.9505, -6.3029, 2.6377, - 3.7418, -5.5913, 2.3037, - 2.8769, -3.1427, 1.2046, - 2.8366, -3.0525, 1.2012] + quintic_coefficients = [4.0848, -6.8946, 2.9270, + 3.9505, -6.3029, 2.6377, + 3.7418, -5.5913, 2.3037, + 2.8769, -3.1427, 1.2046, + 2.8366, -3.0525, 1.2012] + coefficients = quintic_coefficients if args.num_iterations == 5 else [1.5, -0.5, 0.0] * args.num_iterations # Create a random symmetric positive definite matrix on rank 0 # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0, 1) diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py index fd253036cf..95b1335b32 100644 --- a/tests/pytorch/distributed/test_newton_schulz.py +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -21,14 +21,15 @@ @pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) @pytest.mark.parametrize("matrix_size", [256]) -def test_newton_schulz(dtype, matrix_size): +@pytest.mark.parametrize("num_iterations", [5, 15]) +def test_newton_schulz(dtype, matrix_size, num_iterations): """Test distributed Newton-Schulz inverse square root.""" test_path = TEST_ROOT / "run_newton_schulz.py" test_cmd = LAUNCH_CMD + [ str(test_path), f"--dtype={dtype}", f"--matrix-size={matrix_size}", - "--num-iterations=5", + f"--num-iterations={num_iterations}", ] if dtype == "bfloat16": test_cmd += ["--atol=5e-2", "--rtol=5e-2"] From a011231eaa0b3bea4785a5fd6365e603000915b9 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 25 Feb 2026 21:56:07 +0000 Subject: [PATCH 17/30] Remove extra info diag - everything should be in logs Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/newton_schulz/newton_schulz.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index fa1be76b20..a773aaea77 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -182,16 +182,13 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor std::vector workspace_host(wrksp_size_host); // Execute Newton-Schulz - int info = 0; NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz( ctx->handle.get(), ns_desc.get(), m, n, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), - workspace_host.size(), &info)); + workspace_host.size(), nullptr)); // Make the caller's stream wait for the internal stream so that // the output tensor is ready before the caller uses it. NVTE_CHECK_CUDA(cudaEventRecord(ctx->out_ready, ctx->stream)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(caller_stream, ctx->out_ready)); - - NVTE_CHECK(info == 0, "cusolverMpNewtonSchulz failed with info = ", info); } From 7c8a65642dcf5851111321b7f01674cbfbf88b50 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 25 Feb 2026 22:01:07 +0000 Subject: [PATCH 18/30] Add Newton-Schulz tests to the QA script Signed-off-by: Vladimir Cherepanov --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9d868d99cf..db13e9f1e0 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" # debug tests From 59e8afff604d4c4d922b7d6868a5132b408737c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Feb 2026 23:38:09 +0000 Subject: [PATCH 19/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/distributed/run_newton_schulz.py | 26 ++++++++++++++----- .../pytorch/distributed/test_newton_schulz.py | 2 +- .../transformer_engine/newton_schulz.h | 4 +-- .../common/newton_schulz/newton_schulz.cpp | 26 +++++++++---------- transformer_engine/common/util/logging.h | 13 +++++----- transformer_engine/pytorch/csrc/extensions.h | 4 +-- .../pytorch/csrc/extensions/newton_schulz.cpp | 4 +-- .../pytorch/csrc/extensions/pybind.cpp | 3 +-- transformer_engine/pytorch/newton_schulz.py | 4 +-- 9 files changed, 47 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index 223aa23602..a871006131 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -36,12 +36,26 @@ def main(): # Ensure N is divisible by world_size assert N % world_size == 0, f"Matrix size {N} must be divisible by world_size {world_size}" - quintic_coefficients = [4.0848, -6.8946, 2.9270, - 3.9505, -6.3029, 2.6377, - 3.7418, -5.5913, 2.3037, - 2.8769, -3.1427, 1.2046, - 2.8366, -3.0525, 1.2012] - coefficients = quintic_coefficients if args.num_iterations == 5 else [1.5, -0.5, 0.0] * args.num_iterations + quintic_coefficients = [ + 4.0848, + -6.8946, + 2.9270, + 3.9505, + -6.3029, + 2.6377, + 3.7418, + -5.5913, + 2.3037, + 2.8769, + -3.1427, + 1.2046, + 2.8366, + -3.0525, + 1.2012, + ] + coefficients = ( + quintic_coefficients if args.num_iterations == 5 else [1.5, -0.5, 0.0] * args.num_iterations + ) # Create a random symmetric positive definite matrix on rank 0 # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0, 1) diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py index 95b1335b32..2e83a9084a 100644 --- a/tests/pytorch/distributed/test_newton_schulz.py +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -41,7 +41,7 @@ def test_newton_schulz(dtype, matrix_size, num_iterations): or "NUMERICAL CHECK PASSED" not in result.stdout.decode() ): raise AssertionError( - f"Newton-Schulz test failed.\n" + "Newton-Schulz test failed.\n" f"stdout: {result.stdout.decode()}\n" f"stderr: {result.stderr.decode()}" ) diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h index 543c3f63ba..8ea4e7bd14 100644 --- a/transformer_engine/common/include/transformer_engine/newton_schulz.h +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -62,8 +62,8 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); * Used for event-based synchronisation with the internal stream. */ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, - int64_t num_iterations, const float* coefficients, - int64_t num_coefficients, cudaStream_t caller_stream); + int64_t num_iterations, const float* coefficients, int64_t num_coefficients, + cudaStream_t caller_stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index a773aaea77..06b899a603 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -6,8 +6,8 @@ #include "transformer_engine/newton_schulz.h" -#include #include +#include #include #include @@ -23,14 +23,14 @@ using namespace transformer_engine; struct CusolverMpHandleDeleter { void operator()(cusolverMpHandle_t handle) const { cusolverMpDestroy(handle); } }; -using CusolverMpHandle = std::unique_ptr, - CusolverMpHandleDeleter>; +using CusolverMpHandle = + std::unique_ptr, CusolverMpHandleDeleter>; struct CusolverMpGridDeleter { void operator()(cusolverMpGrid_t grid) const { cusolverMpDestroyGrid(grid); } }; -using CusolverMpGrid = std::unique_ptr, - CusolverMpGridDeleter>; +using CusolverMpGrid = + std::unique_ptr, CusolverMpGridDeleter>; namespace { @@ -54,17 +54,16 @@ CusolverMpHandle MakeCusolverMpHandle(int device_id, cudaStream_t stream) { return CusolverMpHandle(raw); } -CusolverMpGrid MakeCusolverMpGrid(cusolverMpHandle_t handle, ncclComm_t comm, - int32_t nprow, int32_t npcol, - cusolverMpGridMapping_t mapping) { +CusolverMpGrid MakeCusolverMpGrid(cusolverMpHandle_t handle, ncclComm_t comm, int32_t nprow, + int32_t npcol, cusolverMpGridMapping_t mapping) { cusolverMpGrid_t raw{}; NVTE_CHECK_CUSOLVERMP(cusolverMpCreateDeviceGrid(handle, &raw, comm, nprow, npcol, mapping)); return CusolverMpGrid(raw); } CusolverMpMatrixDesc MakeCusolverMpMatrixDesc(cusolverMpGrid_t grid, cudaDataType_t dtype, - int64_t m, int64_t n, int64_t mb, int64_t nb, - uint32_t rsrc, uint32_t csrc, int64_t lld) { + int64_t m, int64_t n, int64_t mb, int64_t nb, + uint32_t rsrc, uint32_t csrc, int64_t lld) { cusolverMpMatrixDescriptor_t raw{}; NVTE_CHECK_CUSOLVERMP( cusolverMpCreateMatrixDesc(&raw, grid, dtype, m, n, mb, nb, rsrc, csrc, lld)); @@ -105,8 +104,7 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r NVTE_CHECK_CUDA(cudaEventCreate(&out_ready)); auto handle = MakeCusolverMpHandle(device_id, stream); - auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, - CUSOLVERMP_GRID_MAPPING_COL_MAJOR); + auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, CUSOLVERMP_GRID_MAPPING_COL_MAJOR); return new NVTECusolverMpCtx{ .nranks = nranks, @@ -136,8 +134,8 @@ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { } void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, - int64_t num_iterations, const float* coefficients, - int64_t num_coefficients, cudaStream_t caller_stream) { + int64_t num_iterations, const float* coefficients, int64_t num_coefficients, + cudaStream_t caller_stream) { NVTE_API_CALL(nvte_newton_schulz); const auto* t = convertNVTETensorCheck(x); diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index cbb55e5220..cf72cab048 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -112,13 +112,12 @@ #ifdef NVTE_WITH_CUSOLVERMP -#define NVTE_CHECK_CUSOLVERMP(expr) \ - do { \ - const cusolverStatus_t status_NVTE_CHECK_CUSOLVERMP = (expr); \ - if (status_NVTE_CHECK_CUSOLVERMP != CUSOLVER_STATUS_SUCCESS) { \ - NVTE_ERROR("cuSolverMp Error: ", \ - std::to_string(status_NVTE_CHECK_CUSOLVERMP)); \ - } \ +#define NVTE_CHECK_CUSOLVERMP(expr) \ + do { \ + const cusolverStatus_t status_NVTE_CHECK_CUSOLVERMP = (expr); \ + if (status_NVTE_CHECK_CUSOLVERMP != CUSOLVER_STATUS_SUCCESS) { \ + NVTE_ERROR("cuSolverMp Error: ", std::to_string(status_NVTE_CHECK_CUSOLVERMP)); \ + } \ } while (false) #endif // NVTE_WITH_CUSOLVERMP diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0e9d2c9318..be1b45d3f9 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -512,8 +512,8 @@ int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank); void cusolvermp_ctx_destroy(int64_t ctx_ptr); -void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, - int64_t num_iterations, std::vector coefficients); +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, + std::vector coefficients); } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 34d9bb87f8..5026c79e07 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -32,8 +32,8 @@ void cusolvermp_ctx_destroy(int64_t ctx_ptr) { #endif } -void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, - int64_t num_iterations, std::vector coefficients) { +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, + std::vector coefficients) { #ifdef NVTE_WITH_CUSOLVERMP auto* ctx = reinterpret_cast(ctx_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d393ed0d4a..9cb4f9c260 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -448,8 +448,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Create cuSolverMp context for Newton-Schulz", py::arg("nccl_comm_ptr"), py::arg("nranks"), py::arg("rank"), py::call_guard()); m.def("cusolvermp_ctx_destroy", &transformer_engine::pytorch::cusolvermp_ctx_destroy, - "Destroy cuSolverMp context", py::arg("ctx_ptr"), - py::call_guard()); + "Destroy cuSolverMp context", py::arg("ctx_ptr"), py::call_guard()); m.def("newton_schulz", &transformer_engine::pytorch::newton_schulz, "Newton-Schulz inverse square root", py::arg("ctx_ptr"), py::arg("m"), py::arg("n"), py::arg("x"), py::arg("num_iterations"), py::arg("coefficients"), diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py index f338003168..953e519d75 100644 --- a/transformer_engine/pytorch/newton_schulz.py +++ b/transformer_engine/pytorch/newton_schulz.py @@ -27,9 +27,7 @@ def _get_nccl_comm_ptr(group: dist.ProcessGroup) -> int: """Extract the raw NCCL communicator pointer from a PyTorch process group.""" backend = dist.get_backend(group) if backend != "nccl": - raise RuntimeError( - f"newton_schulz requires NCCL backend, got '{backend}'" - ) + raise RuntimeError(f"newton_schulz requires NCCL backend, got '{backend}'") nccl_backend = group._get_backend(torch.device("cuda")) return nccl_backend._comm_ptr() From e433f06f04f2c8aea9d1daa3ba4df5a145097076 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 25 Feb 2026 15:47:53 -0800 Subject: [PATCH 20/30] Fix outdated comments Signed-off-by: Vladimir Cherepanov --- tests/pytorch/distributed/run_newton_schulz.py | 2 +- tests/pytorch/distributed/test_newton_schulz.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index a871006131..99ad7ea6ff 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -90,7 +90,7 @@ def main(): dist.all_gather(gathered, x_local) X = torch.cat(gathered, dim=0) - # Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix + # Check: the resulting matrix should be orthogonal if rank == 0: XXT = X @ X.t() I = torch.eye(N, device=XXT.device, dtype=XXT.dtype) diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py index 2e83a9084a..f79ab3bc33 100644 --- a/tests/pytorch/distributed/test_newton_schulz.py +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize("matrix_size", [256]) @pytest.mark.parametrize("num_iterations", [5, 15]) def test_newton_schulz(dtype, matrix_size, num_iterations): - """Test distributed Newton-Schulz inverse square root.""" + """Test distributed Newton-Schulz matrix orthogonalization.""" test_path = TEST_ROOT / "run_newton_schulz.py" test_cmd = LAUNCH_CMD + [ str(test_path), From 276b8417092cce05aa32318017db2a1e7221477b Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 25 Feb 2026 15:52:43 -0800 Subject: [PATCH 21/30] Remove unused variable Signed-off-by: Vladimir Cherepanov --- tests/pytorch/distributed/run_newton_schulz.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index 99ad7ea6ff..5d10674e64 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -72,9 +72,6 @@ def main(): # Broadcast the full matrix to all ranks dist.broadcast(A, src=0) - # Keep a copy of the original matrix for verification - A_orig = A.clone() - # Scatter rows to each rank local_rows = N // world_size x_local = A[rank * local_rows : (rank + 1) * local_rows, :].contiguous() From 7fad8944efddba19de96c815635cdf5db5342522 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 26 Feb 2026 00:45:36 +0000 Subject: [PATCH 22/30] Move magic numbers from tests to impl Signed-off-by: Vladimir Cherepanov --- .../pytorch/distributed/run_newton_schulz.py | 23 +------------ .../pytorch/distributed/test_newton_schulz.py | 1 - transformer_engine/pytorch/newton_schulz.py | 32 +++++++++++-------- 3 files changed, 20 insertions(+), 36 deletions(-) diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py index 5d10674e64..96663bda43 100644 --- a/tests/pytorch/distributed/run_newton_schulz.py +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -36,27 +36,6 @@ def main(): # Ensure N is divisible by world_size assert N % world_size == 0, f"Matrix size {N} must be divisible by world_size {world_size}" - quintic_coefficients = [ - 4.0848, - -6.8946, - 2.9270, - 3.9505, - -6.3029, - 2.6377, - 3.7418, - -5.5913, - 2.3037, - 2.8769, - -3.1427, - 1.2046, - 2.8366, - -3.0525, - 1.2012, - ] - coefficients = ( - quintic_coefficients if args.num_iterations == 5 else [1.5, -0.5, 0.0] * args.num_iterations - ) - # Create a random symmetric positive definite matrix on rank 0 # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0, 1) # This ensures Newton-Schulz converges @@ -80,7 +59,7 @@ def main(): from transformer_engine.pytorch.newton_schulz import newton_schulz group = dist.group.WORLD - newton_schulz(x_local, group, args.num_iterations, coefficients) + newton_schulz(x_local, group, args.num_iterations) # Gather results gathered = [torch.empty_like(x_local) for _ in range(world_size)] diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py index f79ab3bc33..f2a0966394 100644 --- a/tests/pytorch/distributed/test_newton_schulz.py +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -18,7 +18,6 @@ NUM_PROCS = torch.cuda.device_count() LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] - @pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) @pytest.mark.parametrize("matrix_size", [256]) @pytest.mark.parametrize("num_iterations", [5, 15]) diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py index 953e519d75..f529dd2fdd 100644 --- a/transformer_engine/pytorch/newton_schulz.py +++ b/transformer_engine/pytorch/newton_schulz.py @@ -12,17 +12,6 @@ import transformer_engine_torch as tex -# Default quintic polynomial coefficients for 5-iteration Newton-Schulz -# from cuSolverMp sample: (3069/1024, -7175/1024, 9009/1024, -6435/1024, 2835/2048) -_DEFAULT_COEFFICIENTS = [ - 3069.0 / 1024.0, - -7175.0 / 1024.0, - 9009.0 / 1024.0, - -6435.0 / 1024.0, - 2835.0 / 2048.0, -] - - def _get_nccl_comm_ptr(group: dist.ProcessGroup) -> int: """Extract the raw NCCL communicator pointer from a PyTorch process group.""" backend = dist.get_backend(group) @@ -51,10 +40,27 @@ def newton_schulz( Number of Newton-Schulz iterations. Default: 5. coefficients : list of float, optional Polynomial coefficients for the Newton-Schulz iteration. - Default: quintic polynomial coefficients from cuSolverMp sample. """ + QUINTIC_COEFFICIENTS = [ + 4.0848, + -6.8946, + 2.9270, + 3.9505, + -6.3029, + 2.6377, + 3.7418, + -5.5913, + 2.3037, + 2.8769, + -3.1427, + 1.2046, + 2.8366, + -3.0525, + 1.2012, + ] if coefficients is None: - coefficients = _DEFAULT_COEFFICIENTS + coefficients = QUINTIC_COEFFICIENTS if num_iterations==5 else [1.5, -0.5, 0.0] * num_iterations + assert len(coefficients) == num_iterations * 3, f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations" if x.dim() != 2: raise ValueError(f"Expected 2D tensor, got {x.dim()}D") From 1e726ce8c3207a01073de0106b6d4e10f1759389 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 00:46:45 +0000 Subject: [PATCH 23/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/distributed/test_newton_schulz.py | 1 + transformer_engine/pytorch/newton_schulz.py | 38 ++++++++++--------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py index f2a0966394..f79ab3bc33 100644 --- a/tests/pytorch/distributed/test_newton_schulz.py +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -18,6 +18,7 @@ NUM_PROCS = torch.cuda.device_count() LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + @pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) @pytest.mark.parametrize("matrix_size", [256]) @pytest.mark.parametrize("num_iterations", [5, 15]) diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py index f529dd2fdd..84b3073eac 100644 --- a/transformer_engine/pytorch/newton_schulz.py +++ b/transformer_engine/pytorch/newton_schulz.py @@ -42,25 +42,29 @@ def newton_schulz( Polynomial coefficients for the Newton-Schulz iteration. """ QUINTIC_COEFFICIENTS = [ - 4.0848, - -6.8946, - 2.9270, - 3.9505, - -6.3029, - 2.6377, - 3.7418, - -5.5913, - 2.3037, - 2.8769, - -3.1427, - 1.2046, - 2.8366, - -3.0525, - 1.2012, + 4.0848, + -6.8946, + 2.9270, + 3.9505, + -6.3029, + 2.6377, + 3.7418, + -5.5913, + 2.3037, + 2.8769, + -3.1427, + 1.2046, + 2.8366, + -3.0525, + 1.2012, ] if coefficients is None: - coefficients = QUINTIC_COEFFICIENTS if num_iterations==5 else [1.5, -0.5, 0.0] * num_iterations - assert len(coefficients) == num_iterations * 3, f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations" + coefficients = ( + QUINTIC_COEFFICIENTS if num_iterations == 5 else [1.5, -0.5, 0.0] * num_iterations + ) + assert ( + len(coefficients) == num_iterations * 3 + ), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations" if x.dim() != 2: raise ValueError(f"Expected 2D tensor, got {x.dim()}D") From fac55db87ae0b404677446b4beea7cb6aac6552c Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 26 Feb 2026 01:04:49 +0000 Subject: [PATCH 24/30] Fix outdated comments Signed-off-by: Vladimir Cherepanov --- tests/pytorch/distributed/test_newton_schulz.py | 2 +- .../common/include/transformer_engine/newton_schulz.h | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/distributed/test_newton_schulz.py b/tests/pytorch/distributed/test_newton_schulz.py index f79ab3bc33..5646d11dd9 100644 --- a/tests/pytorch/distributed/test_newton_schulz.py +++ b/tests/pytorch/distributed/test_newton_schulz.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Tests for distributed Newton-Schulz inverse square root.""" +"""Tests for distributed Newton-Schulz matrix orthogonalization.""" import os import subprocess diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h index 8ea4e7bd14..b604b2db39 100644 --- a/transformer_engine/common/include/transformer_engine/newton_schulz.h +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -5,11 +5,10 @@ ************************************************************************/ /*! \file newton_schulz.h - * \brief Functions for distributed Newton-Schulz inverse square root. + * \brief Functions for distributed Newton-Schulz matrix orthogonalization. * * This API is a TE-native binding to the cuSolverMp library. - * It computes an iterative Newton-Schulz inverse square root - * approximation on a distributed matrix. + * It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix. */ #ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ @@ -45,10 +44,7 @@ NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int r */ void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); -/*! \brief Compute Newton-Schulz inverse square root in-place. - * - * Performs iterative Newton-Schulz approximation of the inverse square root - * on a distributed matrix using cuSolverMp. +/*! \brief Compute Newton-Schulz matrix orthogonalization in-place. * * \param[in] ctx cuSolverMp context. * \param[in] m Global number of rows. From 0732fc2df90755508979a60f1ad55d6f4ba46341 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 26 Feb 2026 01:27:05 +0000 Subject: [PATCH 25/30] Check num_coefficients Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/newton_schulz/newton_schulz.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index 06b899a603..1095992a63 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -137,6 +137,7 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor int64_t num_iterations, const float* coefficients, int64_t num_coefficients, cudaStream_t caller_stream) { NVTE_API_CALL(nvte_newton_schulz); + NVTE_CHECK(num_coefficients == num_iterations * 3, num_iterations, " iterations require ", num_iterations * 3, " coefficients, but ", num_coefficients, " are passed"); const auto* t = convertNVTETensorCheck(x); // Make the internal stream wait for the caller's stream so that From 8eb6028d19e8ea99358fc97065fbe7e9ae6f84cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 01:28:32 +0000 Subject: [PATCH 26/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/newton_schulz/newton_schulz.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp index 1095992a63..bf9a92aac8 100644 --- a/transformer_engine/common/newton_schulz/newton_schulz.cpp +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -137,7 +137,8 @@ void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor int64_t num_iterations, const float* coefficients, int64_t num_coefficients, cudaStream_t caller_stream) { NVTE_API_CALL(nvte_newton_schulz); - NVTE_CHECK(num_coefficients == num_iterations * 3, num_iterations, " iterations require ", num_iterations * 3, " coefficients, but ", num_coefficients, " are passed"); + NVTE_CHECK(num_coefficients == num_iterations * 3, num_iterations, " iterations require ", + num_iterations * 3, " coefficients, but ", num_coefficients, " are passed"); const auto* t = convertNVTETensorCheck(x); // Make the internal stream wait for the caller's stream so that From 8f50bd59d198775b91c2b645f9486398f621f368 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 27 Feb 2026 03:01:17 +0000 Subject: [PATCH 27/30] Auto-detect cuSolverMp support from common library binary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of requiring NVTE_WITH_CUSOLVERMP env var to be set for both the common library and PyTorch extension builds, inspect the already-built libtransformer_engine.so for exported symbols. This is more robust for incremental builds and CI environments where the env var may not be propagated to the extension build step. The PyTorch extension only calls nvte_* C API functions, so it does not need cusolverMp headers or libraries — only the compile definition. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- build_tools/pytorch.py | 18 ++++++++++++------ build_tools/utils.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 4530f1c4d7..7985df508c 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -8,7 +8,13 @@ import setuptools -from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled +from .utils import ( + all_files_in_dir, + common_lib_has_symbol, + cuda_version, + get_cuda_include_dirs, + debug_build_enabled, +) from typing import List @@ -87,11 +93,11 @@ def setup_pytorch_extension( libraries.append("nvshmem_host") cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") - if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))): - cusolvermp_home = Path(os.getenv("CUSOLVERMP_HOME", "/usr")) - include_dirs.append(cusolvermp_home / "include") - library_dirs.append(cusolvermp_home / "lib") - libraries.append("cusolverMp") + # Auto-detect cuSolverMp support from the already-built common library. + # The PyTorch extension only calls nvte_* C API functions (not cusolverMp + # directly), so it only needs the compile definition, not cusolverMp + # headers or libraries. + if common_lib_has_symbol("nvte_cusolvermp_ctx_create"): cxx_flags.append("-DNVTE_WITH_CUSOLVERMP") # Construct PyTorch CUDA extension diff --git a/build_tools/utils.py b/build_tools/utils.py index 885901068a..6d817c3315 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -374,3 +374,37 @@ def copy_common_headers( new_path = dst_dir / path.relative_to(src_dir) new_path.parent.mkdir(exist_ok=True, parents=True) shutil.copy(path, new_path) + + +def common_lib_has_symbol(symbol: str) -> bool: + """Check if the built libtransformer_engine.so exports a given symbol. + + Searches for the library in known build/install locations and uses + ``nm -D --defined-only`` to inspect the dynamic symbol table. + """ + root = Path(__file__).resolve().parent.parent + + # Candidate paths: editable-install root, default CMake build dir, + # and user-specified CMake build dir. + candidates = [ + root / "libtransformer_engine.so", + root / "build" / "cmake" / "libtransformer_engine.so", + ] + custom_build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR") + if custom_build_dir: + candidates.append(Path(custom_build_dir) / "libtransformer_engine.so") + + for lib_path in candidates: + if not lib_path.is_file(): + continue + try: + result = subprocess.run( + ["nm", "-D", "--defined-only", str(lib_path)], + capture_output=True, + text=True, + check=True, + ) + return symbol in result.stdout + except (subprocess.CalledProcessError, FileNotFoundError): + continue + return False From bb991812bd81a339ff91b521fac428333536d1fb Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 27 Feb 2026 03:06:28 +0000 Subject: [PATCH 28/30] Conditionally exclude Newton-Schulz API from PyTorch extension When NVTE_WITH_CUSOLVERMP is not defined, omit the Newton-Schulz functions entirely from the pybind module instead of registering stubs that throw runtime errors. The Python wrapper checks for the attribute at call time and raises a clear error message. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Vladimir Cherepanov --- transformer_engine/pytorch/csrc/extensions.h | 4 ++++ .../pytorch/csrc/extensions/newton_schulz.cpp | 19 ++++--------------- .../pytorch/csrc/extensions/pybind.cpp | 2 ++ transformer_engine/pytorch/newton_schulz.py | 5 +++++ 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index be1b45d3f9..27585bf0c7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -508,6 +508,8 @@ void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at: * Newton-Schulz (cuSolverMp) **************************************************************************************************/ +#ifdef NVTE_WITH_CUSOLVERMP + int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank); void cusolvermp_ctx_destroy(int64_t ctx_ptr); @@ -515,6 +517,8 @@ void cusolvermp_ctx_destroy(int64_t ctx_ptr); void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, std::vector coefficients); +#endif // NVTE_WITH_CUSOLVERMP + } // namespace transformer_engine::pytorch /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 5026c79e07..2a16696cee 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -4,37 +4,27 @@ * See LICENSE for license information. ************************************************************************/ +#ifdef NVTE_WITH_CUSOLVERMP + #include "../extensions.h" -#ifdef NVTE_WITH_CUSOLVERMP #include "transformer_engine/newton_schulz.h" -#endif namespace transformer_engine::pytorch { int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank) { -#ifdef NVTE_WITH_CUSOLVERMP auto comm = reinterpret_cast(nccl_comm_ptr); auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank); return reinterpret_cast(ctx); -#else - NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); - return 0; -#endif } void cusolvermp_ctx_destroy(int64_t ctx_ptr) { -#ifdef NVTE_WITH_CUSOLVERMP auto* ctx = reinterpret_cast(ctx_ptr); nvte_cusolvermp_ctx_destroy(ctx); -#else - NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); -#endif } void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, std::vector coefficients) { -#ifdef NVTE_WITH_CUSOLVERMP auto* ctx = reinterpret_cast(ctx_ptr); // Build NVTETensor from PyTorch tensor @@ -47,9 +37,8 @@ void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t auto caller_stream = at::cuda::getCurrentCUDAStream().stream(); nvte_newton_schulz(ctx, m, n, x_tensor.data(), num_iterations, coefficients.data(), static_cast(coefficients.size()), caller_stream); -#else - NVTE_ERROR("newton_schulz requires building with NVTE_WITH_CUSOLVERMP=1"); -#endif } } // namespace transformer_engine::pytorch + +#endif // NVTE_WITH_CUSOLVERMP diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 9cb4f9c260..71fecf5651 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -443,6 +443,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda, "Fused compute E8M0 scale_inv from amax", py::call_guard()); +#ifdef NVTE_WITH_CUSOLVERMP // Newton-Schulz (cuSolverMp) m.def("cusolvermp_ctx_create", &transformer_engine::pytorch::cusolvermp_ctx_create, "Create cuSolverMp context for Newton-Schulz", py::arg("nccl_comm_ptr"), py::arg("nranks"), @@ -453,6 +454,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Newton-Schulz inverse square root", py::arg("ctx_ptr"), py::arg("m"), py::arg("n"), py::arg("x"), py::arg("num_iterations"), py::arg("coefficients"), py::call_guard()); +#endif // NVTE_WITH_CUSOLVERMP // Comm+GEMM Overlap m.def("bulk_overlap_ag_with_external_gemm", diff --git a/transformer_engine/pytorch/newton_schulz.py b/transformer_engine/pytorch/newton_schulz.py index 84b3073eac..3e4c844bdc 100644 --- a/transformer_engine/pytorch/newton_schulz.py +++ b/transformer_engine/pytorch/newton_schulz.py @@ -71,6 +71,11 @@ def newton_schulz( if not x.is_cuda: raise ValueError("Input tensor must be on CUDA device") + if not hasattr(tex, "newton_schulz"): + raise RuntimeError( + "newton_schulz requires Transformer Engine to be built with NVTE_WITH_CUSOLVERMP=1" + ) + nccl_comm_ptr = _get_nccl_comm_ptr(group) nranks = dist.get_world_size(group) rank = dist.get_rank(group) From d3740fb00176e521317b20ffb22118b4e1885cba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 03:11:02 +0000 Subject: [PATCH 29/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp index 2a16696cee..3615d5be31 100644 --- a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -6,10 +6,10 @@ #ifdef NVTE_WITH_CUSOLVERMP -#include "../extensions.h" - #include "transformer_engine/newton_schulz.h" +#include "../extensions.h" + namespace transformer_engine::pytorch { int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank) { From f17b65767092dda10d16966fead7be60dd33e224 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 27 Feb 2026 12:49:31 -0800 Subject: [PATCH 30/30] Make symbol detection errors fatal in common_lib_has_symbol Raise FileNotFoundError when no libtransformer_engine.so is found in any candidate location, and raise RuntimeError when nm is unavailable or exits non-zero, rather than silently returning False in both cases. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Vladimir Cherepanov --- build_tools/utils.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/build_tools/utils.py b/build_tools/utils.py index 6d817c3315..3fd537db3d 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -394,17 +394,30 @@ def common_lib_has_symbol(symbol: str) -> bool: if custom_build_dir: candidates.append(Path(custom_build_dir) / "libtransformer_engine.so") - for lib_path in candidates: - if not lib_path.is_file(): - continue - try: - result = subprocess.run( - ["nm", "-D", "--defined-only", str(lib_path)], - capture_output=True, - text=True, - check=True, - ) - return symbol in result.stdout - except (subprocess.CalledProcessError, FileNotFoundError): - continue - return False + lib_path = None + for candidate in candidates: + if candidate.is_file(): + lib_path = candidate + break + + if lib_path is None: + raise FileNotFoundError( + "Could not find libtransformer_engine.so in any of the expected locations: " + + ", ".join(str(c) for c in candidates) + ) + + try: + result = subprocess.run( + ["nm", "-D", "--defined-only", str(lib_path)], + capture_output=True, + text=True, + check=True, + ) + except FileNotFoundError as e: + raise RuntimeError("'nm' is not available on this system.") from e + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"'nm' failed on {lib_path} (exit code {e.returncode}):\n{e.stderr}" + ) from e + + return symbol in result.stdout