Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
948037e
[Common] Add Newton-Schulz inverse square root C API via cuSolverMp
vcherepanov-nv Feb 8, 2026
efefa7e
[PyTorch] Add Newton-Schulz PyTorch bindings and distributed tests
vcherepanov-nv Feb 8, 2026
02299b3
[Common] Fix cuSolverMp API signatures in Newton-Schulz implementation
vcherepanov-nv Feb 8, 2026
ed6c21f
[PyTorch] Propagate NVTE_WITH_CUSOLVERMP define to PyTorch extension …
vcherepanov-nv Feb 8, 2026
fbc1c4e
[PyTorch] Fix NCCL comm extraction and pass global dims to Newton-Schulz
vcherepanov-nv Feb 9, 2026
f8b23cc
[Common] Cache cuSolverMp handle and grid in Newton-Schulz context
vcherepanov-nv Feb 18, 2026
8dbdcbb
[Common] Create dedicated CUDA stream in Newton-Schulz context
vcherepanov-nv Feb 18, 2026
1c01a9d
[Common] Fix Newton-Schulz zero output with event-based stream sync
vcherepanov-nv Feb 18, 2026
1d8115d
[Common] Fix Newton-Schulz NaNs by keeping host workspace alive
vcherepanov-nv Feb 18, 2026
fcb4d33
[Common] Cache CUDA event in Newton-Schulz context
vcherepanov-nv Feb 18, 2026
b4422b9
[Common] Use separate in/out events for Newton-Schulz stream sync
vcherepanov-nv Feb 18, 2026
e8c51f8
Correct coefficients
vcherepanov-nv Feb 18, 2026
412445c
No stream synchronize
vcherepanov-nv Feb 18, 2026
9645073
[Test] Verify Newton-Schulz result with XAX=I identity check
vcherepanov-nv Feb 18, 2026
dd1dd0b
Change test - it approximates orthogonal matrix, not inverse square root
vcherepanov-nv Feb 19, 2026
85d33fb
Generalize number of iterations in tests
vcherepanov-nv Feb 19, 2026
a011231
Remove extra info diag - everything should be in logs
vcherepanov-nv Feb 25, 2026
7c8a656
Add Newton-Schulz tests to the QA script
vcherepanov-nv Feb 25, 2026
59e8aff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2026
e433f06
Fix outdated comments
vcherepanov-nv Feb 25, 2026
276b841
Remove unused variable
vcherepanov-nv Feb 25, 2026
7fad894
Move magic numbers from tests to impl
vcherepanov-nv Feb 26, 2026
1e726ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2026
fac55db
Fix outdated comments
vcherepanov-nv Feb 26, 2026
0732fc2
Check num_coefficients
vcherepanov-nv Feb 26, 2026
8eb6028
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2026
8f50bd5
Auto-detect cuSolverMp support from common library binary
vcherepanov-nv Feb 27, 2026
bb99181
Conditionally exclude Newton-Schulz API from PyTorch extension
vcherepanov-nv Feb 27, 2026
d3740fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2026
f17b657
Make symbol detection errors fatal in common_lib_has_symbol
vcherepanov-nv Feb 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@

import setuptools

from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled
from .utils import (
all_files_in_dir,
common_lib_has_symbol,
cuda_version,
get_cuda_include_dirs,
debug_build_enabled,
)
from typing import List


Expand Down Expand Up @@ -87,6 +93,13 @@ def setup_pytorch_extension(
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

# Auto-detect cuSolverMp support from the already-built common library.
# The PyTorch extension only calls nvte_* C API functions (not cusolverMp
# directly), so it only needs the compile definition, not cusolverMp
# headers or libraries.
if common_lib_has_symbol("nvte_cusolvermp_ctx_create"):
cxx_flags.append("-DNVTE_WITH_CUSOLVERMP")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
Expand Down
47 changes: 47 additions & 0 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,50 @@ def copy_common_headers(
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)


def common_lib_has_symbol(symbol: str) -> bool:
    """Check if the built libtransformer_engine.so exports a given symbol.

    Searches for the library in known build/install locations and uses
    ``nm -D --defined-only`` to inspect the dynamic symbol table.

    Parameters
    ----------
    symbol : str
        Name of the dynamic symbol to look for.

    Returns
    -------
    bool
        True if the symbol appears in the library's dynamic symbol table.

    Raises
    ------
    FileNotFoundError
        If libtransformer_engine.so is not found in any known location.
    RuntimeError
        If ``nm`` is unavailable or fails on the library.
    """
    root = Path(__file__).resolve().parent.parent

    # Candidate paths, most specific first: a user-specified CMake build dir
    # takes precedence over the editable-install root and the default CMake
    # build dir, so a stale library in a default location cannot shadow the
    # build the user explicitly pointed at.
    candidates = []
    custom_build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
    if custom_build_dir:
        candidates.append(Path(custom_build_dir) / "libtransformer_engine.so")
    candidates += [
        root / "libtransformer_engine.so",
        root / "build" / "cmake" / "libtransformer_engine.so",
    ]

    lib_path = next((c for c in candidates if c.is_file()), None)
    if lib_path is None:
        raise FileNotFoundError(
            "Could not find libtransformer_engine.so in any of the expected locations: "
            + ", ".join(str(c) for c in candidates)
        )

    try:
        result = subprocess.run(
            ["nm", "-D", "--defined-only", str(lib_path)],
            capture_output=True,
            text=True,
            check=True,
        )
    except FileNotFoundError as e:
        raise RuntimeError("'nm' is not available on this system.") from e
    except subprocess.CalledProcessError as e:
        raise RuntimeError(
            f"'nm' failed on {lib_path} (exit code {e.returncode}):\n{e.stderr}"
        ) from e

    return symbol in result.stdout
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"


# debug tests
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
print("CMAKE_FLAGS:", cmake_flags[-2:])

if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))):
cmake_flags.append("-DNVTE_WITH_CUSOLVERMP=ON")
cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr")
cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}")

# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
Expand Down
86 changes: 86 additions & 0 deletions tests/pytorch/distributed/run_newton_schulz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Distributed Newton-Schulz test worker.

Launched via torchrun from test_newton_schulz.py.
"""

import argparse
import sys

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record


@record
def main():
    """Run one rank of the distributed Newton-Schulz orthogonalization test.

    Builds a random symmetric positive definite matrix on rank 0, broadcasts
    it, runs the distributed Newton-Schulz iteration on this rank's row
    shard, gathers the full result, and (on rank 0) checks that the output
    is approximately orthogonal. Exits with status 1 if the check fails.
    """
    import os  # local import, matching the file's deferred-import style

    parser = argparse.ArgumentParser(description="Newton-Schulz distributed test")
    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"])
    parser.add_argument("--matrix-size", type=int, default=256)
    parser.add_argument("--num-iterations", type=int, default=5)
    parser.add_argument("--atol", type=float, default=1e-2)
    parser.add_argument("--rtol", type=float, default=1e-2)
    args = parser.parse_args()

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Use LOCAL_RANK (set by torchrun) for device selection: the global rank
    # is only correct on a single node with default device visibility.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)

    dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
    N = args.matrix_size

    # The matrix is row-sharded evenly across ranks below.
    assert N % world_size == 0, f"Matrix size {N} must be divisible by world_size {world_size}"

    # Create a random symmetric positive definite matrix on rank 0:
    # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0, 1).
    # This ensures Newton-Schulz converges.
    if rank == 0:
        torch.manual_seed(42)
        Q, _ = torch.linalg.qr(torch.randn(N, N, device="cuda", dtype=torch.float32))
        eigenvalues = torch.rand(N, device="cuda", dtype=torch.float32) * 0.8 + 0.1
        A = Q @ torch.diag(eigenvalues) @ Q.T
        A = A.to(dtype)
    else:
        A = torch.empty(N, N, device="cuda", dtype=dtype)

    # Broadcast the full matrix to all ranks, then slice out this rank's rows.
    dist.broadcast(A, src=0)
    local_rows = N // world_size
    x_local = A[rank * local_rows : (rank + 1) * local_rows, :].contiguous()

    # Imported here (after distributed setup) rather than at module top.
    from transformer_engine.pytorch.newton_schulz import newton_schulz

    # Run the distributed Newton-Schulz in-place on the local shard.
    newton_schulz(x_local, dist.group.WORLD, args.num_iterations)

    # Gather the sharded result back into the full matrix.
    gathered = [torch.empty_like(x_local) for _ in range(world_size)]
    dist.all_gather(gathered, x_local)
    X = torch.cat(gathered, dim=0)

    # Check: the resulting matrix should be orthogonal, i.e. X @ X^T ~ I.
    failed = False
    if rank == 0:
        XXT = X @ X.t()
        I = torch.eye(N, device=XXT.device, dtype=XXT.dtype)
        max_diff = (XXT - I).abs().max().item()
        print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)

        if torch.allclose(XXT, I, atol=args.atol, rtol=args.rtol):
            print("NUMERICAL CHECK PASSED", flush=True)
        else:
            print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr)
            failed = True

    # Tear down the process group before exiting: calling sys.exit() first
    # (as the original did) raises SystemExit and skips the teardown on the
    # failure path, leaving NCCL resources dangling.
    dist.destroy_process_group()
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    # Entry point when launched via torchrun (see test_newton_schulz.py).
    main()
47 changes: 47 additions & 0 deletions tests/pytorch/distributed/test_newton_schulz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for distributed Newton-Schulz matrix orthogonalization."""

import os
import subprocess
from pathlib import Path

import pytest
import torch

# Newton-Schulz here is distributed; a single GPU cannot exercise it.
if torch.cuda.device_count() < 2:
    pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True)

TEST_ROOT = Path(__file__).parent.resolve()  # directory containing the worker script
NUM_PROCS = torch.cuda.device_count()  # one rank per visible GPU
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_size", [256])
@pytest.mark.parametrize("num_iterations", [5, 15])
def test_newton_schulz(dtype, matrix_size, num_iterations):
    """Test distributed Newton-Schulz matrix orthogonalization.

    Launches run_newton_schulz.py under torchrun and checks both the exit
    status and the explicit PASSED/FAILED markers printed by the worker.
    """
    test_path = TEST_ROOT / "run_newton_schulz.py"
    test_cmd = LAUNCH_CMD + [
        str(test_path),
        f"--dtype={dtype}",
        f"--matrix-size={matrix_size}",
        f"--num-iterations={num_iterations}",
    ]
    # bfloat16 has far fewer mantissa bits, so loosen the tolerances.
    if dtype == "bfloat16":
        test_cmd += ["--atol=5e-2", "--rtol=5e-2"]

    # A timeout bounds CI time if the distributed run deadlocks (e.g. an
    # NCCL hang); without it, subprocess.run would block indefinitely.
    try:
        result = subprocess.run(
            test_cmd, env=os.environ, capture_output=True, check=False, timeout=300
        )
    except subprocess.TimeoutExpired as e:
        raise AssertionError(
            f"Newton-Schulz test timed out after {e.timeout} seconds.\n"
            f"stdout: {(e.stdout or b'').decode()}\n"
            f"stderr: {(e.stderr or b'').decode()}"
        ) from e

    if (
        result.returncode != 0
        or "NUMERICAL CHECK FAILED" in result.stderr.decode()
        or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
    ):
        raise AssertionError(
            "Newton-Schulz test failed.\n"
            f"stdout: {result.stdout.decode()}\n"
            f"stderr: {result.stderr.decode()}"
        )
18 changes: 18 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ list(APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp)
endif()

if (NVTE_WITH_CUSOLVERMP)
list(APPEND transformer_engine_SOURCES
newton_schulz/newton_schulz.cpp)
endif()

add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
Expand Down Expand Up @@ -300,6 +305,19 @@ if (NVTE_WITH_CUBLASMP)
message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif()

option(NVTE_WITH_CUSOLVERMP "Use cuSolverMp for distributed Newton-Schulz" OFF)
if (NVTE_WITH_CUSOLVERMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUSOLVERMP)
target_include_directories(transformer_engine PRIVATE ${CUSOLVERMP_DIR}/include)
find_library(CUSOLVERMP_LIB
NAMES cusolverMp libcusolverMp
PATHS ${CUSOLVERMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUSOLVERMP_LIB})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PUBLIC linkage exposes cuSOLVERMp to all downstream consumers of transformer_engine library. Since newton_schulz.h doesn't expose cuSOLVERMp types in the public API, PRIVATE linkage would provide better encapsulation (consumers don't need cuSOLVERMp at link time).

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}")
endif()

# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file newton_schulz.h
* \brief Functions for distributed Newton-Schulz matrix orthogonalization.
*
* This API is a TE-native binding to the cuSolverMp library.
* It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix.
*/

#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_

#include <nccl.h>
#include <stdint.h>

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#else
#include <stdbool.h>
#endif

typedef struct NVTECusolverMpCtx NVTECusolverMpCtx;

/*! \brief Create a cuSolverMp context for Newton-Schulz operations.
*
* Creates a dedicated CUDA stream internally (cuSolverMp requires a
* non-default stream).
*
* \param[in] comm NCCL communicator.
* \param[in] nranks Number of ranks.
* \param[in] rank Local rank.
*/
NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank);

/*! \brief Destroy a cuSolverMp context.
*
* \param[in] ctx Context to destroy.
*/
void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx);

/*! \brief Compute Newton-Schulz matrix orthogonalization in-place.
*
* \param[in] ctx cuSolverMp context.
* \param[in] m Global number of rows.
* \param[in] n Global number of columns.
* \param[in,out] x Local part of the matrix (modified in-place).
* \param[in] num_iterations Number of Newton-Schulz iterations.
* \param[in] coefficients Array of polynomial coefficients (length depends on polynomial
* degree used internally by cuSolverMp).
* \param[in] num_coefficients Number of elements in the coefficients array.
* \param[in] caller_stream CUDA stream on which the caller produced the input tensor.
* Used for event-based synchronisation with the internal stream.
*/
void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x,
int64_t num_iterations, const float* coefficients, int64_t num_coefficients,
cudaStream_t caller_stream);

#ifdef __cplusplus
} // extern "C"
#endif

#endif // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
Loading
Loading