diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b9f8128..fcd19e4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,5 +18,5 @@ jobs: # Run go vet on all packages except those with intentional # unsafe.Pointer usage for GPU runtime bindings via purego/dlopen. # These warnings are expected and documented in docs/QUALITY.md. - go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$') + go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$' | grep -v '/internal/nccl$') - run: go test -race -timeout 300s ./... diff --git a/docs/adr/002-nccl-purego.md b/docs/adr/002-nccl-purego.md new file mode 100644 index 0000000..ec56661 --- /dev/null +++ b/docs/adr/002-nccl-purego.md @@ -0,0 +1,81 @@ +# ADR-002: Migrate internal/nccl from CGo to purego/dlopen + +**Status:** Accepted +**Date:** 2026-04-09 +**Authors:** David Ndungu + +## Context + +The original `internal/nccl` package was a CGo binding (`#include `, +`-lnccl`) gated behind `//go:build cuda`. This forced any build that wanted +NCCL to link against the system NCCL headers/library at compile time and +exposed the package only when the `cuda` build tag was set. It also broke the +project-wide rule that `go build ./...` must succeed on every supported +platform without CGo. + +Every other GPU runtime binding in ztensor (cuBLAS, cuDNN, cuRAND, TensorRT, +HIP/ROCm, OpenCL) is already loaded at runtime via `internal/cuda.DlopenPath` +and `internal/cuda.Ccall`. NCCL was the lone holdout. 
+
+## Decision
+
+`internal/nccl` is implemented in Go-only via runtime dlopen of
+`libnccl.so.2`, mirroring the pattern in `internal/cublas/cublas_purego.go`.
+
+Key points:
+
+- **No build tag** on `nccl_purego.go` — it compiles on every platform.
+- On non-linux GOOS, `loadNccl` returns a `nccl: not supported on $GOOS`
+  error without attempting `dlopen`. Every exported entry point surfaces this
+  as a clean error rather than a panic.
+- ABI constants (`ncclSuccess`, the data-type and reduction-op enums, and
+  `NCCL_UNIQUE_ID_BYTES = 128`) are hardcoded against the stable NCCL 2.x ABI.
+- `UniqueID` is marshaled as a fixed-size `[128]byte` array. `Bytes()` and
+  `UniqueIDFromBytes` provide the serialization round-trip used to ferry the
+  bootstrap blob between ranks.
+- The legacy CGo implementation is retained as `nccl_cgo.go` behind
+  `//go:build cuda && cgo && nccl_cgo`. It is OFF by default and exists only
+  as a debugging fallback if a future NCCL release introduces an ABI quirk
+  that the dlopen path cannot handle.
+
+### AArch64 hidden-pointer ABI for ncclCommInitRank
+
+`ncclCommInitRank` takes the 128-byte `ncclUniqueId` **by value**:
+
+```c
+ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks,
+                              ncclUniqueId commId, int rank);
+```
+
+The shared `cuda.Ccall` trampoline only marshals `uintptr`-sized arguments,
+so passing 128 bytes by value is not directly possible. Fortunately the only
+supported NCCL platform for ztensor is **linux/arm64** (the DGX Spark host).
+Per AAPCS64 §6.8.2 rule B.4, *"if the argument type is a Composite Type that
+is larger than 16 bytes, then the argument is copied to memory allocated by
+the caller and the argument is replaced by a pointer to the copy."* That
+means at the ABI level the third argument is already a pointer; passing
+`uintptr(unsafe.Pointer(&uid.id[0]))` is the correct calling convention.
+If we ever need to support NCCL on linux/amd64 (System V ABI), this trick
+will not work — SysV passes large aggregates on the stack by value — and we
+will need a small assembly trampoline, or we can fall back to the CGo path
+via the `nccl_cgo` build tag.
+
+## Consequences
+
+- `go build ./...` works everywhere with no `-tags cuda` and no system NCCL
+  installed.
+- Tests in `internal/nccl/nccl_test.go` no longer require a build tag; they
+  call `requireNccl(t)` to skip when `libnccl.so.2` is not dlopen-able. Pure
+  constant and marshaling tests run on every platform.
+- The duplicate `internal/nccl` copy inside the `zerfoo` repository is **not**
+  touched by this change and should be migrated in a follow-up.
+- CI's `go vet` exclude list adds `/internal/nccl$` since the dlopen
+  trampoline relies on the same `unsafe.Pointer(uintptr(...))` pattern as
+  every other GPU runtime binding (already documented in `docs/QUALITY.md`).
+
+## References
+
+- Issue: zerfoo/ztensor#78
+- Reference pattern: `internal/cublas/cublas_purego.go`
+- AAPCS64: https://github.com/ARM-software/abi-aa (Procedure Call Standard for the Arm 64-bit Architecture)
diff --git a/internal/nccl/doc.go b/internal/nccl/doc.go
index afecdf7..e9d8c58 100644
--- a/internal/nccl/doc.go
+++ b/internal/nccl/doc.go
@@ -1,4 +1,5 @@
-// Package nccl provides CGo bindings for the NVIDIA Collective Communications
-// Library (NCCL). All functional code requires the "cuda" build tag and a
-// working NCCL installation.
+// Package nccl provides a zero-CGo binding for the NVIDIA Collective
+// Communications Library (NCCL). The library is loaded at runtime via dlopen
+// (see nccl_purego.go); a legacy CGo implementation is retained behind the
+// `nccl_cgo` build tag for opt-in fallback (nccl_cgo.go).
package nccl diff --git a/internal/nccl/nccl.go b/internal/nccl/nccl_cgo.go similarity index 94% rename from internal/nccl/nccl.go rename to internal/nccl/nccl_cgo.go index daa94bb..b69480b 100644 --- a/internal/nccl/nccl.go +++ b/internal/nccl/nccl_cgo.go @@ -1,9 +1,8 @@ -//go:build cuda +//go:build cuda && cgo && nccl_cgo -// Package nccl provides CGo bindings for the NVIDIA Collective Communications -// Library (NCCL). It exposes communicator lifecycle, all-reduce, and broadcast -// operations needed for multi-GPU gradient synchronization and tensor -// distribution. +// Legacy CGo binding for libnccl, retained as an opt-in fallback. Build with +// `-tags "cuda nccl_cgo"` to use this implementation instead of the default +// purego/dlopen path in nccl_purego.go. package nccl /* diff --git a/internal/nccl/nccl_purego.go b/internal/nccl/nccl_purego.go new file mode 100644 index 0000000..2fce6fb --- /dev/null +++ b/internal/nccl/nccl_purego.go @@ -0,0 +1,352 @@ +// Zero-CGo binding to libnccl.so.2 loaded at runtime via dlopen. On non-linux +// platforms every exported entry point returns a clean "not supported" error +// without attempting dlopen. On linux the library is dlopen'd lazily on first +// use; if libnccl.so.2 cannot be found the same error is returned. +// +// On AArch64 (DGX hardware) AAPCS64 rule B.4 means aggregates larger than 16 +// bytes are passed by hidden pointer, which lets us hand the 128-byte +// ncclUniqueId to ncclCommInitRank as a plain uintptr without an ABI +// trampoline. The legacy CGo implementation is retained behind the +// `nccl_cgo` build tag (see nccl_cgo.go). +package nccl + +import ( + "fmt" + "runtime" + "sync" + "unsafe" + + "github.com/zerfoo/ztensor/internal/cuda" +) + +// NCCL_UNIQUE_ID_BYTES is the fixed size of an ncclUniqueId. +const ncclUniqueIDBytes = 128 + +// NCCL result codes. +const ncclSuccess = 0 + +// NCCL data type enum (stable ABI for NCCL 2.x). 
+const ( + ncclInt8 = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclFloat32 = 7 + ncclFloat64 = 8 + ncclBfloat16 = 9 +) + +// NCCL reduction op enum (stable ABI for NCCL 2.x). +const ( + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 +) + +// DataType specifies the element type for NCCL operations. +type DataType int + +const ( + Float32 DataType = ncclFloat32 + Float64 DataType = ncclFloat64 + Int32 DataType = ncclInt32 + Int64 DataType = ncclInt64 +) + +// ReduceOp specifies the reduction operation for collective calls. +type ReduceOp int + +const ( + Sum ReduceOp = ncclSum + Avg ReduceOp = ncclAvg + Max ReduceOp = ncclMax + Min ReduceOp = ncclMin +) + +// ncclLib holds dlsym-resolved function pointers for libnccl.so.2. +type ncclLib struct { + getUniqueId uintptr // ncclGetUniqueId + commInitRank uintptr // ncclCommInitRank + commDestroy uintptr // ncclCommDestroy + commGetAsyncError uintptr // ncclCommGetAsyncError + allReduce uintptr // ncclAllReduce + broadcast uintptr // ncclBroadcast + groupStart uintptr // ncclGroupStart + groupEnd uintptr // ncclGroupEnd + getErrorString uintptr // ncclGetErrorString +} + +var ( + ncclLibInst *ncclLib + ncclOnce sync.Once + ncclLoadErr error +) + +// libnccl candidate paths (linux only). 
+var ncclLibPaths = []string{ + "libnccl.so.2", + "libnccl.so", +} + +func loadNccl() (*ncclLib, error) { + if runtime.GOOS != "linux" { + return nil, fmt.Errorf("nccl: not supported on %s", runtime.GOOS) + } + var handle uintptr + var lastErr string + for _, path := range ncclLibPaths { + h, err := cuda.DlopenPath(path) + if err == nil { + handle = h + break + } + lastErr = err.Error() + } + if handle == 0 { + return nil, fmt.Errorf("nccl: dlopen failed: %s", lastErr) + } + + lib := &ncclLib{} + type sym struct { + name string + ptr *uintptr + } + syms := []sym{ + {"ncclGetUniqueId", &lib.getUniqueId}, + {"ncclCommInitRank", &lib.commInitRank}, + {"ncclCommDestroy", &lib.commDestroy}, + {"ncclCommGetAsyncError", &lib.commGetAsyncError}, + {"ncclAllReduce", &lib.allReduce}, + {"ncclBroadcast", &lib.broadcast}, + {"ncclGroupStart", &lib.groupStart}, + {"ncclGroupEnd", &lib.groupEnd}, + {"ncclGetErrorString", &lib.getErrorString}, + } + for _, s := range syms { + addr, err := cuda.Dlsym(handle, s.name) + if err != nil { + return nil, fmt.Errorf("nccl: %w", err) + } + *s.ptr = addr + } + return lib, nil +} + +func getNcclLib() (*ncclLib, error) { + ncclOnce.Do(func() { + ncclLibInst, ncclLoadErr = loadNccl() + }) + return ncclLibInst, ncclLoadErr +} + +// Available returns true if libnccl can be loaded at runtime. +func Available() bool { + _, err := getNcclLib() + return err == nil +} + +// errorString returns the human-readable error string for an NCCL result code. +// Falls back to a numeric description if ncclGetErrorString cannot be invoked. +func (l *ncclLib) errorString(rc uintptr) string { + if l == nil || l.getErrorString == 0 { + return fmt.Sprintf("ncclResult=%d", rc) + } + cstr := cuda.Ccall(l.getErrorString, rc) + if cstr == 0 { + return fmt.Sprintf("ncclResult=%d", rc) + } + // Read C-string at cstr. 
+ var b []byte + for i := 0; i < 1024; i++ { + c := *(*byte)(unsafe.Pointer(cstr + uintptr(i))) + if c == 0 { + break + } + b = append(b, c) + } + return string(b) +} + +// UniqueID wraps an ncclUniqueId (128-byte opaque blob) used to bootstrap +// communicator creation. +type UniqueID struct { + id [ncclUniqueIDBytes]byte +} + +// GetUniqueID generates a new unique ID for communicator initialization. +// Exactly one rank should call this and broadcast the result to all other ranks. +func GetUniqueID() (*UniqueID, error) { + lib, err := getNcclLib() + if err != nil { + return nil, err + } + uid := &UniqueID{} + rc := cuda.Ccall(lib.getUniqueId, uintptr(unsafe.Pointer(&uid.id[0]))) + if rc != ncclSuccess { + return nil, fmt.Errorf("ncclGetUniqueId failed: %s", lib.errorString(rc)) + } + return uid, nil +} + +// Bytes returns a copy of the raw bytes of the unique ID for serialization. +func (u *UniqueID) Bytes() []byte { + out := make([]byte, ncclUniqueIDBytes) + copy(out, u.id[:]) + return out +} + +// UniqueIDFromBytes reconstructs a UniqueID from raw bytes. +func UniqueIDFromBytes(b []byte) (*UniqueID, error) { + if len(b) != ncclUniqueIDBytes { + return nil, fmt.Errorf("UniqueIDFromBytes: expected %d bytes, got %d", ncclUniqueIDBytes, len(b)) + } + uid := &UniqueID{} + copy(uid.id[:], b) + return uid, nil +} + +// Comm wraps an ncclComm_t communicator (opaque pointer). +type Comm struct { + comm uintptr +} + +// InitRank initializes a communicator for a given rank in a group of nRanks. +// All ranks must call this with the same UniqueID and nRanks. The CUDA device +// for this rank must be set via cuda.SetDevice before calling InitRank. +// +// On AArch64 (AAPCS64) and other AAPCS-derived ABIs, aggregates larger than +// 16 bytes are passed by hidden pointer (rule B.4), so passing &uid.id[0] +// matches the C calling convention for ncclUniqueId-by-value. 
This binding +// is therefore correct on linux/arm64 (the supported NCCL platform); other +// ABIs (System V AMD64) pass large aggregates on the stack and would need a +// dedicated trampoline. +func InitRank(uid *UniqueID, nRanks, rank int) (*Comm, error) { + if uid == nil { + return nil, fmt.Errorf("nccl InitRank: nil UniqueID") + } + lib, err := getNcclLib() + if err != nil { + return nil, err + } + var comm uintptr + rc := cuda.Ccall(lib.commInitRank, + uintptr(unsafe.Pointer(&comm)), + uintptr(nRanks), + uintptr(unsafe.Pointer(&uid.id[0])), + uintptr(rank), + ) + if rc != ncclSuccess { + return nil, fmt.Errorf("ncclCommInitRank(nRanks=%d, rank=%d) failed: %s", + nRanks, rank, lib.errorString(rc)) + } + return &Comm{comm: comm}, nil +} + +// Destroy releases the communicator resources. +func (c *Comm) Destroy() error { + lib, err := getNcclLib() + if err != nil { + return err + } + rc := cuda.Ccall(lib.commDestroy, c.comm) + if rc != ncclSuccess { + return fmt.Errorf("ncclCommDestroy failed: %s", lib.errorString(rc)) + } + return nil +} + +// AllReduce performs an in-place all-reduce across all ranks. sendBuf and +// recvBuf may be the same pointer for in-place operation. count is the number +// of elements (not bytes). The stream parameter is a cudaStream_t as +// unsafe.Pointer. +func (c *Comm) AllReduce(sendBuf, recvBuf unsafe.Pointer, count int, dtype DataType, op ReduceOp, stream unsafe.Pointer) error { + lib, err := getNcclLib() + if err != nil { + return err + } + rc := cuda.Ccall(lib.allReduce, + uintptr(sendBuf), + uintptr(recvBuf), + uintptr(count), + uintptr(dtype), + uintptr(op), + c.comm, + uintptr(stream), + ) + if rc != ncclSuccess { + return fmt.Errorf("ncclAllReduce failed: %s", lib.errorString(rc)) + } + return nil +} + +// Broadcast sends count elements from root's sendBuf to all ranks' recvBuf. +// For root, sendBuf and recvBuf may differ or be the same. 
+func (c *Comm) Broadcast(sendBuf, recvBuf unsafe.Pointer, count int, dtype DataType, root int, stream unsafe.Pointer) error { + lib, err := getNcclLib() + if err != nil { + return err + } + rc := cuda.Ccall(lib.broadcast, + uintptr(sendBuf), + uintptr(recvBuf), + uintptr(count), + uintptr(dtype), + uintptr(root), + c.comm, + uintptr(stream), + ) + if rc != ncclSuccess { + return fmt.Errorf("ncclBroadcast failed: %s", lib.errorString(rc)) + } + return nil +} + +// GroupStart begins a group of NCCL operations. All NCCL calls between +// GroupStart and GroupEnd are batched into a single launch. +func GroupStart() error { + lib, err := getNcclLib() + if err != nil { + return err + } + rc := cuda.Ccall(lib.groupStart) + if rc != ncclSuccess { + return fmt.Errorf("ncclGroupStart failed: %s", lib.errorString(rc)) + } + return nil +} + +// GroupEnd completes a group of NCCL operations and launches them. +func GroupEnd() error { + lib, err := getNcclLib() + if err != nil { + return err + } + rc := cuda.Ccall(lib.groupEnd) + if rc != ncclSuccess { + return fmt.Errorf("ncclGroupEnd failed: %s", lib.errorString(rc)) + } + return nil +} + +// GetAsyncError queries the communicator for any asynchronous errors that +// occurred during previous operations. 
+func (c *Comm) GetAsyncError() error { + lib, err := getNcclLib() + if err != nil { + return err + } + var result uintptr + rc := cuda.Ccall(lib.commGetAsyncError, c.comm, uintptr(unsafe.Pointer(&result))) + if rc != ncclSuccess { + return fmt.Errorf("ncclCommGetAsyncError query failed: %s", lib.errorString(rc)) + } + if result != ncclSuccess { + return fmt.Errorf("NCCL async error: %s", lib.errorString(result)) + } + return nil +} diff --git a/internal/nccl/nccl_test.go b/internal/nccl/nccl_test.go index 201f47f..3a09825 100644 --- a/internal/nccl/nccl_test.go +++ b/internal/nccl/nccl_test.go @@ -1,5 +1,3 @@ -//go:build cuda - package nccl import ( @@ -10,7 +8,49 @@ import ( "github.com/zerfoo/ztensor/internal/cuda" ) +// requireNccl skips the test when libnccl.so.2 cannot be dlopen'd. Pure +// constant/marshaling tests do not need this guard and should not call it. +func requireNccl(t *testing.T) { + t.Helper() + if !Available() { + t.Skip("libnccl.so.2 not available on this host") + } +} + +func TestConstants(t *testing.T) { + if Float32 != 7 || Float64 != 8 || Int32 != 2 || Int64 != 4 { + t.Fatalf("unexpected NCCL DataType ABI constants: f32=%d f64=%d i32=%d i64=%d", + Float32, Float64, Int32, Int64) + } + if Sum != 0 || Avg != 4 || Max != 2 || Min != 3 { + t.Fatalf("unexpected NCCL ReduceOp ABI constants: sum=%d avg=%d max=%d min=%d", + Sum, Avg, Max, Min) + } +} + +func TestUniqueIDFromBytesRoundTripNoLib(t *testing.T) { + // This exercises the pure-Go marshaling path and runs on every platform. 
+ src := make([]byte, 128) + for i := range src { + src[i] = byte(i) + } + uid, err := UniqueIDFromBytes(src) + if err != nil { + t.Fatalf("UniqueIDFromBytes: %v", err) + } + out := uid.Bytes() + if len(out) != 128 { + t.Fatalf("Bytes length = %d, want 128", len(out)) + } + for i := range src { + if out[i] != src[i] { + t.Fatalf("byte %d: got %d want %d", i, out[i], src[i]) + } + } +} + func TestGetUniqueID(t *testing.T) { + requireNccl(t) uid, err := GetUniqueID() if err != nil { t.Fatalf("GetUniqueID: %v", err) @@ -22,6 +62,7 @@ func TestGetUniqueID(t *testing.T) { } func TestUniqueIDRoundTrip(t *testing.T) { + requireNccl(t) uid, err := GetUniqueID() if err != nil { t.Fatalf("GetUniqueID: %v", err) @@ -50,6 +91,7 @@ func TestUniqueIDFromBytesInvalidLength(t *testing.T) { } func TestSingleRankInitDestroy(t *testing.T) { + requireNccl(t) count, err := cuda.GetDeviceCount() if err != nil || count < 1 { t.Skip("requires at least 1 CUDA device") @@ -74,6 +116,7 @@ func TestSingleRankInitDestroy(t *testing.T) { } func TestSingleRankAllReduce(t *testing.T) { + requireNccl(t) count, err := cuda.GetDeviceCount() if err != nil || count < 1 { t.Skip("requires at least 1 CUDA device") @@ -134,6 +177,7 @@ func TestSingleRankAllReduce(t *testing.T) { } func TestTwoGPUAllReduce(t *testing.T) { + requireNccl(t) count, err := cuda.GetDeviceCount() if err != nil || count < 2 { t.Skip("requires at least 2 CUDA devices") @@ -231,6 +275,7 @@ func TestTwoGPUAllReduce(t *testing.T) { } func TestTwoGPUBroadcast(t *testing.T) { + requireNccl(t) count, err := cuda.GetDeviceCount() if err != nil || count < 2 { t.Skip("requires at least 2 CUDA devices") @@ -326,6 +371,7 @@ func TestTwoGPUBroadcast(t *testing.T) { } func TestGroupStartEnd(t *testing.T) { + requireNccl(t) // GroupStart/GroupEnd can be called without a communicator. if err := GroupStart(); err != nil { t.Fatalf("GroupStart: %v", err)