-
Notifications
You must be signed in to change notification settings - Fork 650
[Draft] Newton-Schulz via cuSOLVERMp #2706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
948037e
efefa7e
02299b3
ed6c21f
fbc1c4e
f8b23cc
8dbdcbb
1c01a9d
1d8115d
fcb4d33
b4422b9
e8c51f8
412445c
9645073
dd1dd0b
85d33fb
a011231
7c8a656
59e8aff
e433f06
276b841
7fad894
1e726ce
fac55db
0732fc2
8eb6028
8f50bd5
bb99181
d3740fb
f17b657
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Distributed Newton-Schulz test worker. | ||
|
|
||
| Launched via torchrun from test_newton_schulz.py. | ||
| """ | ||
|
|
||
| import argparse | ||
| import sys | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from torch.distributed.elastic.multiprocessing.errors import record | ||
|
|
||
|
|
||
@record
def main():
    """Distributed Newton-Schulz worker entry point.

    Launched via torchrun from test_newton_schulz.py: builds a random SPD
    matrix on rank 0, broadcasts it, runs the distributed Newton-Schulz
    iteration on one contiguous row-shard per rank, gathers the shards back,
    and verifies on rank 0 that the result is orthogonal. Exits with status 1
    (after printing "NUMERICAL CHECK FAILED") when the check fails.
    """
    parser = argparse.ArgumentParser(description="Newton-Schulz distributed test")
    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"])
    parser.add_argument("--matrix-size", type=int, default=256)
    parser.add_argument("--num-iterations", type=int, default=5)
    parser.add_argument("--atol", type=float, default=1e-2)
    parser.add_argument("--rtol", type=float, default=1e-2)
    args = parser.parse_args()

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # NOTE(review): assumes one process per GPU on a single node, where the
    # global rank equals the local device index; a multi-node launch would
    # need LOCAL_RANK here instead -- confirm if multi-node is ever intended.
    torch.cuda.set_device(rank)

    try:
        dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
        N = args.matrix_size

        # Each rank owns an equal contiguous row block, so N must split evenly.
        # Explicit raise instead of `assert`, which is stripped under -O.
        if N % world_size != 0:
            raise ValueError(f"Matrix size {N} must be divisible by world_size {world_size}")

        # Create a random symmetric positive definite matrix on rank 0:
        # A = Q @ diag(eigenvalues) @ Q^T with eigenvalues in (0.1, 0.9).
        # This ensures Newton-Schulz converges.
        if rank == 0:
            torch.manual_seed(42)
            Q, _ = torch.linalg.qr(torch.randn(N, N, device="cuda", dtype=torch.float32))
            eigenvalues = torch.rand(N, device="cuda", dtype=torch.float32) * 0.8 + 0.1
            A = Q @ torch.diag(eigenvalues) @ Q.T
            A = A.to(dtype)
        else:
            A = torch.empty(N, N, device="cuda", dtype=dtype)

        # Broadcast the full matrix so every rank can slice out its shard.
        dist.broadcast(A, src=0)

        # Row-shard: rank r keeps rows [r * local_rows, (r + 1) * local_rows).
        local_rows = N // world_size
        x_local = A[rank * local_rows : (rank + 1) * local_rows, :].contiguous()

        # Imported lazily so argument/launch errors surface before TE loads.
        from transformer_engine.pytorch.newton_schulz import newton_schulz

        group = dist.group.WORLD
        newton_schulz(x_local, group, args.num_iterations)

        # Reassemble the full matrix from the per-rank shards.
        gathered = [torch.empty_like(x_local) for _ in range(world_size)]
        dist.all_gather(gathered, x_local)
        X = torch.cat(gathered, dim=0)

        # Orthogonality check on rank 0: X @ X^T should be close to identity.
        if rank == 0:
            XXT = X @ X.t()
            identity = torch.eye(N, device=XXT.device, dtype=XXT.dtype)
            max_diff = (XXT - identity).abs().max().item()
            print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)

            if torch.allclose(XXT, identity, atol=args.atol, rtol=args.rtol):
                print("NUMERICAL CHECK PASSED", flush=True)
            else:
                print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr)
                sys.exit(1)
    finally:
        # Always tear down the process group -- including on the sys.exit(1)
        # failure path -- so NCCL resources are released cleanly.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Tests for distributed Newton-Schulz matrix orthogonalization.""" | ||
|
|
||
| import os | ||
| import subprocess | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
# Distributed tests run one worker process per GPU; skip the whole module on
# single-GPU or CPU-only machines.
if torch.cuda.device_count() < 2:
    pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True)

# Directory containing this file (and the torchrun worker script).
TEST_ROOT = Path(__file__).parent.resolve()
# One worker process per visible GPU.
NUM_PROCS = torch.cuda.device_count()
# Base torchrun command; the worker script and its args are appended per test.
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
|
|
||
|
|
||
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_size", [256])
@pytest.mark.parametrize("num_iterations", [5, 15])
def test_newton_schulz(dtype, matrix_size, num_iterations):
    """Test distributed Newton-Schulz matrix orthogonalization.

    Launches run_newton_schulz.py under torchrun and asserts on both the exit
    code and the PASSED/FAILED markers the worker prints.
    """
    test_path = TEST_ROOT / "run_newton_schulz.py"
    test_cmd = LAUNCH_CMD + [
        str(test_path),
        f"--dtype={dtype}",
        f"--matrix-size={matrix_size}",
        f"--num-iterations={num_iterations}",
    ]
    # bfloat16 carries far fewer mantissa bits, so loosen the tolerances.
    if dtype == "bfloat16":
        test_cmd += ["--atol=5e-2", "--rtol=5e-2"]

    # Bound the run so a deadlocked/hung distributed job (e.g. a NCCL
    # communication failure) fails the test instead of blocking CI forever.
    try:
        result = subprocess.run(
            test_cmd, env=os.environ, capture_output=True, check=False, timeout=300
        )
    except subprocess.TimeoutExpired as e:
        raise AssertionError(
            "Newton-Schulz test timed out after 300 seconds.\n"
            f"stdout: {(e.stdout or b'').decode()}\n"
            f"stderr: {(e.stderr or b'').decode()}"
        ) from e

    stdout = result.stdout.decode()
    stderr = result.stderr.decode()
    if (
        result.returncode != 0
        or "NUMERICAL CHECK FAILED" in stderr
        or "NUMERICAL CHECK PASSED" not in stdout
    ):
        raise AssertionError(
            "Newton-Schulz test failed.\n"
            f"stdout: {stdout}\n"
            f"stderr: {stderr}"
        )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -227,6 +227,11 @@ list(APPEND transformer_engine_SOURCES | |
| comm_gemm/comm_gemm.cpp) | ||
| endif() | ||
|
|
||
# Compile the cuSolverMp-backed Newton-Schulz binding only when requested.
if (NVTE_WITH_CUSOLVERMP)
  list(APPEND transformer_engine_SOURCES
       newton_schulz/newton_schulz.cpp)
endif()
|
|
||
| add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) | ||
| target_include_directories(transformer_engine PUBLIC | ||
| "${CMAKE_CURRENT_SOURCE_DIR}/include") | ||
|
|
@@ -300,6 +305,19 @@ if (NVTE_WITH_CUBLASMP) | |
| message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") | ||
| endif() | ||
|
|
||
# NOTE(review): this option is already consulted earlier (when assembling
# transformer_engine_SOURCES), before option() runs here. That works when the
# flag is supplied on the command line via -D, but consider hoisting the
# option() declaration above its first use.
option(NVTE_WITH_CUSOLVERMP "Use cuSolverMp for distributed Newton-Schulz" OFF)
if (NVTE_WITH_CUSOLVERMP)
  target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUSOLVERMP)
  # CUSOLVERMP_DIR is presumably an install prefix with include/ and lib/
  # subdirectories, set via -D or the environment -- TODO confirm/document.
  target_include_directories(transformer_engine PRIVATE ${CUSOLVERMP_DIR}/include)
  find_library(CUSOLVERMP_LIB
               NAMES cusolverMp libcusolverMp
               PATHS ${CUSOLVERMP_DIR}
               PATH_SUFFIXES lib
               REQUIRED)
  target_link_libraries(transformer_engine PUBLIC ${CUSOLVERMP_LIB})
  message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}")
endif()
|
|
||
| # Hack to enable dynamic loading in cuDNN frontend | ||
| target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| /*! \file newton_schulz.h | ||
| * \brief Functions for distributed Newton-Schulz matrix orthogonalization. | ||
| * | ||
| * This API is a TE-native binding to the cuSolverMp library. | ||
| * It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix. | ||
| */ | ||
|
|
||
#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_

/* NOTE(review): cudaStream_t is used below but no CUDA header is included
 * directly -- this relies on <nccl.h> transitively pulling in
 * cuda_runtime.h. Consider including it explicitly. */
#include <nccl.h>
#include <stdint.h>

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#else
#include <stdbool.h>
#endif

/*! \brief Opaque handle for the cuSolverMp state used by the Newton-Schulz
 *  entry points below. Created by nvte_cusolvermp_ctx_create() and released
 *  by nvte_cusolvermp_ctx_destroy(). */
typedef struct NVTECusolverMpCtx NVTECusolverMpCtx;

/*! \brief Create a cuSolverMp context for Newton-Schulz operations.
 *
 *  Creates a dedicated CUDA stream internally (cuSolverMp requires a
 *  non-default stream).
 *
 *  \param[in] comm   NCCL communicator.
 *  \param[in] nranks Number of ranks.
 *  \param[in] rank   Local rank.
 *
 *  \return Newly allocated context; pass to nvte_cusolvermp_ctx_destroy()
 *          when done.
 */
NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank);

/*! \brief Destroy a cuSolverMp context.
 *
 *  \param[in] ctx Context to destroy.
 */
void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx);

/*! \brief Compute Newton-Schulz matrix orthogonalization in-place.
 *
 *  \param[in]     ctx              cuSolverMp context.
 *  \param[in]     m                Global number of rows.
 *  \param[in]     n                Global number of columns.
 *  \param[in,out] x                Local part of the matrix (modified in-place).
 *  \param[in]     num_iterations   Number of Newton-Schulz iterations.
 *  \param[in]     coefficients     Array of polynomial coefficients (length depends on
 *                                  polynomial degree used internally by cuSolverMp).
 *  \param[in]     num_coefficients Number of elements in the coefficients array.
 *  \param[in]     caller_stream    CUDA stream on which the caller produced the input
 *                                  tensor. Used for event-based synchronisation with
 *                                  the internal stream.
 */
void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x,
                        int64_t num_iterations, const float* coefficients, int64_t num_coefficients,
                        cudaStream_t caller_stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No timeout on the subprocess — if the distributed test deadlocks or hangs (e.g., due to NCCL communication issues), it will block CI indefinitely. Add `timeout=300` or similar.