@@ -13,21 +13,32 @@ fn main() {}
1313
1414#[ cfg( not( target_os = "macos" ) ) ]
1515fn main ( ) {
16+ // Compile the bridge.cpp file
17+ let mut cc_builder = cc:: Build :: new ( ) ;
18+ cc_builder
19+ . cpp ( true )
20+ . file ( "src/bridge.cpp" )
21+ . flag ( "-std=c++14" ) ;
22+
23+ // Include CUDA headers
24+ if let Some ( cuda_home) = build_utils:: find_cuda_home ( ) {
25+ cc_builder. include ( format ! ( "{}/include" , cuda_home) ) ;
26+ }
27+
28+ cc_builder. compile ( "nccl_bridge" ) ;
29+
1630 let mut builder = bindgen:: Builder :: default ( )
17- . header ( "src/nccl .h" )
31+ . header ( "src/bridge .h" )
1832 . clang_arg ( "-x" )
1933 . clang_arg ( "c++" )
2034 . clang_arg ( "-std=c++14" )
21- . clang_arg ( format ! (
22- "-I{}/include" ,
23- build_utils:: find_cuda_home( ) . unwrap( )
24- ) )
2535 . parse_callbacks ( Box :: new ( bindgen:: CargoCallbacks :: new ( ) ) )
26- // Communicator creation and management
27- . allowlist_function ( "ncclGetLastError" )
28- . allowlist_function ( "ncclGetErrorString" )
36+ // Version and error handling
2937 . allowlist_function ( "ncclGetVersion" )
3038 . allowlist_function ( "ncclGetUniqueId" )
39+ . allowlist_function ( "ncclGetErrorString" )
40+ . allowlist_function ( "ncclGetLastError" )
41+ // Communicator creation and management
3142 . allowlist_function ( "ncclCommInitRank" )
3243 . allowlist_function ( "ncclCommInitAll" )
3344 . allowlist_function ( "ncclCommInitRankConfig" )
@@ -60,15 +71,20 @@ fn main() {
6071 // User-defined reduction operators
6172 . allowlist_function ( "ncclRedOpCreatePreMulSum" )
6273 . allowlist_function ( "ncclRedOpDestroy" )
63- // Random nccl stuff we want
64- . allowlist_function ( "cudaStream.*" )
74+ // CUDA runtime functions
6575 . allowlist_function ( "cudaSetDevice" )
76+ . allowlist_function ( "cudaStreamSynchronize" )
77+ // Types
6678 . allowlist_type ( "ncclComm_t" )
6779 . allowlist_type ( "ncclResult_t" )
6880 . allowlist_type ( "ncclDataType_t" )
6981 . allowlist_type ( "ncclRedOp_t" )
7082 . allowlist_type ( "ncclScalarResidence_t" )
7183 . allowlist_type ( "ncclSimInfo_t" )
84+ . allowlist_type ( "ncclConfig_t" )
85+ . allowlist_type ( "cudaError_t" )
86+ . allowlist_type ( "cudaStream_t" )
87+ // Constants
7288 . allowlist_var ( "NCCL_SPLIT_NOCOLOR" )
7389 . allowlist_var ( "NCCL_MAJOR" )
7490 . allowlist_var ( "NCCL_MINOR" )
@@ -79,6 +95,11 @@ fn main() {
7995 is_global : false ,
8096 } ) ;
8197
98+ // Include CUDA headers
99+ if let Some ( cuda_home) = build_utils:: find_cuda_home ( ) {
100+ builder = builder. clang_arg ( format ! ( "-I{}/include" , cuda_home) ) ;
101+ }
102+
82103 // Include headers and libs from the active environment
83104 let python_config = match build_utils:: python_env_dirs ( ) {
84105 Ok ( config) => config,
@@ -103,13 +124,17 @@ fn main() {
103124
104125 // Write the bindings to the $OUT_DIR/bindings.rs file.
105126 let out_path = PathBuf :: from ( std:: env:: var ( "OUT_DIR" ) . unwrap ( ) ) ;
127+
128+ // Generate bindings (NCCL + CUDA runtime combined)
106129 builder
107130 . generate ( )
108131 . expect ( "Unable to generate bindings" )
109132 . write_to_file ( out_path. join ( "bindings.rs" ) )
110133 . expect ( "Couldn't write bindings!" ) ;
111134
112- println ! ( "cargo::rustc-link-lib=nccl" ) ;
135+ // We no longer link against nccl directly since we dlopen it
136+ // But we do link against CUDA runtime
137+ println ! ( "cargo::rustc-link-lib=cudart" ) ;
113138 println ! ( "cargo::rustc-cfg=cargo" ) ;
114139 println ! ( "cargo::rustc-check-cfg=cfg(cargo)" ) ;
115140}
0 commit comments