Skip to content

Commit bca67e4

Browse files
committed
Load nccl dynamically
Pull Request resolved: #2088. NCCL can be a big library, and we do not want to force our build to try to load it on a machine that may not have CUDA. So we do the same as we do with the libcuda API and load it dynamically. ghstack-source-id: 328272743 @exported-using-ghexport Differential Revision: [D88672908](https://our.internmc.facebook.com/intern/diff/D88672908/)
1 parent b0e5118 commit bca67e4

File tree

13 files changed

+773
-65
lines changed

13 files changed

+773
-65
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@ docs/_build/**
3333
docs/build/**
3434
docs/**/generated/**
3535
*/sg_execution_times.rst
36+
nccl/**

nccl-sys/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ serde = { version = "1.0.185", features = ["derive", "rc"] }
1313
[build-dependencies]
1414
bindgen = "0.70.1"
1515
build_utils = { path = "../build_utils" }
16+
cc = "1.0"

nccl-sys/build.rs

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,32 @@ fn main() {}
1313

1414
#[cfg(not(target_os = "macos"))]
1515
fn main() {
16+
// Compile the bridge.cpp file
17+
let mut cc_builder = cc::Build::new();
18+
cc_builder
19+
.cpp(true)
20+
.file("src/bridge.cpp")
21+
.flag("-std=c++14");
22+
23+
// Include CUDA headers
24+
if let Some(cuda_home) = build_utils::find_cuda_home() {
25+
cc_builder.include(format!("{}/include", cuda_home));
26+
}
27+
28+
cc_builder.compile("nccl_bridge");
29+
1630
let mut builder = bindgen::Builder::default()
17-
.header("src/nccl.h")
31+
.header("src/bridge.h")
1832
.clang_arg("-x")
1933
.clang_arg("c++")
2034
.clang_arg("-std=c++14")
21-
.clang_arg(format!(
22-
"-I{}/include",
23-
build_utils::find_cuda_home().unwrap()
24-
))
2535
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
26-
// Communicator creation and management
27-
.allowlist_function("ncclGetLastError")
28-
.allowlist_function("ncclGetErrorString")
36+
// Version and error handling
2937
.allowlist_function("ncclGetVersion")
3038
.allowlist_function("ncclGetUniqueId")
39+
.allowlist_function("ncclGetErrorString")
40+
.allowlist_function("ncclGetLastError")
41+
// Communicator creation and management
3142
.allowlist_function("ncclCommInitRank")
3243
.allowlist_function("ncclCommInitAll")
3344
.allowlist_function("ncclCommInitRankConfig")
@@ -60,15 +71,20 @@ fn main() {
6071
// User-defined reduction operators
6172
.allowlist_function("ncclRedOpCreatePreMulSum")
6273
.allowlist_function("ncclRedOpDestroy")
63-
// Random nccl stuff we want
64-
.allowlist_function("cudaStream.*")
74+
// CUDA runtime functions
6575
.allowlist_function("cudaSetDevice")
76+
.allowlist_function("cudaStreamSynchronize")
77+
// Types
6678
.allowlist_type("ncclComm_t")
6779
.allowlist_type("ncclResult_t")
6880
.allowlist_type("ncclDataType_t")
6981
.allowlist_type("ncclRedOp_t")
7082
.allowlist_type("ncclScalarResidence_t")
7183
.allowlist_type("ncclSimInfo_t")
84+
.allowlist_type("ncclConfig_t")
85+
.allowlist_type("cudaError_t")
86+
.allowlist_type("cudaStream_t")
87+
// Constants
7288
.allowlist_var("NCCL_SPLIT_NOCOLOR")
7389
.allowlist_var("NCCL_MAJOR")
7490
.allowlist_var("NCCL_MINOR")
@@ -79,6 +95,11 @@ fn main() {
7995
is_global: false,
8096
});
8197

98+
// Include CUDA headers
99+
if let Some(cuda_home) = build_utils::find_cuda_home() {
100+
builder = builder.clang_arg(format!("-I{}/include", cuda_home));
101+
}
102+
82103
// Include headers and libs from the active environment
83104
let python_config = match build_utils::python_env_dirs() {
84105
Ok(config) => config,
@@ -103,13 +124,17 @@ fn main() {
103124

104125
// Write the bindings to the $OUT_DIR/bindings.rs file.
105126
let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap());
127+
128+
// Generate bindings (NCCL + CUDA runtime combined)
106129
builder
107130
.generate()
108131
.expect("Unable to generate bindings")
109132
.write_to_file(out_path.join("bindings.rs"))
110133
.expect("Couldn't write bindings!");
111134

112-
println!("cargo::rustc-link-lib=nccl");
135+
// We no longer link against nccl directly since we dlopen it
136+
// But we do link against CUDA runtime
137+
println!("cargo::rustc-link-lib=cudart");
113138
println!("cargo::rustc-cfg=cargo");
114139
println!("cargo::rustc-check-cfg=cfg(cargo)");
115140
}

0 commit comments

Comments
 (0)