From 5dbcaa657f8103531ad1d3d0cc9e64f43cfbf54a Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Tue, 2 Dec 2025 16:41:13 +0000 Subject: [PATCH 01/12] Add hipify_torch submodule --- .gitmodules | 3 +++ deps/hipify_torch | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 deps/hipify_torch diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..291d65ed9 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "deps/hipify_torch"] + path = deps/hipify_torch + url = https://github.com/ROCm/hipify_torch diff --git a/deps/hipify_torch b/deps/hipify_torch new file mode 160000 index 000000000..ee928d80e --- /dev/null +++ b/deps/hipify_torch @@ -0,0 +1 @@ +Subproject commit ee928d80eb49a74be5d556465e04c6a40de7e3bc From e0b902df0574716810a55bba61a0640a891bb7f8 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Wed, 3 Dec 2025 21:31:44 +0000 Subject: [PATCH 02/12] build works, now a few unit tests fail --- build_utils/src/lib.rs | 249 +++++++- cuda-sys/build.rs | 268 ++++++-- monarch_rdma/build.rs | 95 ++- monarch_rdma/src/test_utils.rs | 31 +- rdmaxcel-sys/build.rs | 1090 +++++++++++++++++++++++++------- rdmaxcel-sys/src/lib.rs | 232 ++++--- 6 files changed, 1545 insertions(+), 420 deletions(-) diff --git a/build_utils/src/lib.rs b/build_utils/src/lib.rs index e210ad04b..22ee6c58a 100644 --- a/build_utils/src/lib.rs +++ b/build_utils/src/lib.rs @@ -8,12 +8,15 @@ //! Build utilities shared across monarch *-sys crates //! -//! This module provides common functionality for Python environment discovery -//! and CUDA installation detection used by various build scripts. +//! This module provides common functionality for Python environment discovery, +//! CUDA installation detection, and ROCm installation detection used by various +//! build scripts. use std::env; +use std::fs; use std::path::Path; use std::path::PathBuf; +use std::process::Command; use glob::glob; use which::which; @@ -64,6 +67,14 @@ pub struct CudaConfig { pub lib_dirs: Vec, } +/// Configuration structure for ROCm environment +#[derive(Debug, Clone, Default)] +pub struct RocmConfig { + pub rocm_home: Option, + pub include_dirs: Vec, + pub lib_dirs: Vec, +} + /// Result of Python environment discovery #[derive(Debug, Clone)] pub struct PythonConfig { @@ -75,6 +86,7 @@ pub struct PythonConfig { #[derive(Debug)] pub enum BuildError { CudaNotFound, + RocmNotFound, PythonNotFound, CommandFailed(String), PathNotFound(String), @@ -84,6 +96,7 @@ impl std::fmt::Display for BuildError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { BuildError::CudaNotFound => write!(f, "CUDA installation not found"), + BuildError::RocmNotFound => write!(f, "ROCm installation not found"), BuildError::PythonNotFound => write!(f, "Python interpreter not found"), BuildError::CommandFailed(cmd) => write!(f, "Command failed: {}", cmd), BuildError::PathNotFound(path) => write!(f, "Path not found: {}", path), @@ -99,6 +112,24 @@ pub fn get_env_var_with_rerun(name: &str) -> Result env::var(name) } +/// Finds the python interpreter, preferring `python3` if available. +/// +/// This function checks in order: +/// 1. PYO3_PYTHON environment variable +/// 2. `python3` command availability +/// 3. Falls back to `python` +pub fn find_python_interpreter() -> PathBuf { + get_env_var_with_rerun("PYO3_PYTHON") + .map(PathBuf::from) + .unwrap_or_else(|_| { + if Command::new("python3").arg("--version").output().is_ok() { + PathBuf::from("python3") + } else { + PathBuf::from("python") + } + }) +} + /// Find CUDA home directory using various heuristics /// /// This function attempts to locate CUDA installation through: @@ -276,6 +307,212 @@ pub fn print_cuda_lib_error_help() { eprintln!("Or: export CUDA_LIB_DIR=/usr/lib64"); } +/// Find ROCm home directory using various heuristics +/// +/// This function attempts to locate ROCm installation through: +/// 1. ROCM_HOME environment variable +/// 2. ROCM_PATH environment variable +/// 3. Platform-specific default locations (/opt/rocm-* or /opt/rocm) +/// 4. Finding hipcc in PATH and deriving ROCm home +pub fn find_rocm_home() -> Option { + // Guess #1: Environment variables + let mut rocm_home = get_env_var_with_rerun("ROCM_HOME") + .ok() + .or_else(|| get_env_var_with_rerun("ROCM_PATH").ok()); + + if rocm_home.is_none() { + // Guess #2: Platform-specific defaults (check these before PATH to avoid /usr) + // Check for versioned ROCm installations + let pattern = "/opt/rocm-*"; + if let Ok(entries) = glob(pattern) { + let mut rocm_homes: Vec<_> = entries.filter_map(Result::ok).collect(); + if !rocm_homes.is_empty() { + // Sort to get the most recent version + rocm_homes.sort(); + rocm_homes.reverse(); + rocm_home = Some(rocm_homes[0].to_string_lossy().into_owned()); + } + } + + // Fallback to /opt/rocm symlink + if rocm_home.is_none() { + let rocm_candidate = "/opt/rocm"; + if Path::new(rocm_candidate).exists() { + rocm_home = Some(rocm_candidate.to_string()); + } + } + + // Guess #3: Find hipcc in PATH (only if nothing else found) + if rocm_home.is_none() { + if let Ok(hipcc_path) = which("hipcc") { + // Get parent directory twice (hipcc is in ROCM_HOME/bin) + // But avoid using /usr as ROCm home + if let Some(rocm_dir) = hipcc_path.parent().and_then(|p| p.parent()) { + let rocm_str = rocm_dir.to_string_lossy(); + if rocm_str != "/usr" { + rocm_home = Some(rocm_str.into_owned()); + } + } + } + } + } + + rocm_home +} + +/// Detects ROCm version and returns (major, minor) or None if not found +/// +/// This function attempts to detect ROCm version through: +/// 1. Reading .info/version file in ROCm home +/// 2. Parsing hipcc --version output +/// +/// Returns `None` if version cannot be detected. Callers should provide +/// their own default (e.g., `.unwrap_or((6, 0))`). +pub fn get_rocm_version(rocm_home: &str) -> Option<(u32, u32)> { + // Try to read ROCm version from .info/version file + let version_file = PathBuf::from(rocm_home).join(".info").join("version"); + if let Ok(content) = fs::read_to_string(&version_file) { + let trimmed = content.trim(); + if let Some((major_str, rest)) = trimmed.split_once('.') { + if let Some((minor_str, _)) = rest.split_once('.') { + if let (Ok(major), Ok(minor)) = (major_str.parse::(), minor_str.parse::()) + { + println!( + "cargo:warning=Detected ROCm version {}.{} from {}", + major, + minor, + version_file.display() + ); + return Some((major, minor)); + } + } + } + } + + // Fallback: try hipcc --version + let hipcc_path = format!("{}/bin/hipcc", rocm_home); + if let Ok(output) = Command::new(&hipcc_path).arg("--version").output() { + let version_output = String::from_utf8_lossy(&output.stdout); + // Look for version pattern like "HIP version: 6.2.41134" + for line in version_output.lines() { + if line.contains("HIP version:") { + if let Some(version_part) = line.split("HIP version:").nth(1) { + let version_str = version_part.trim(); + if let Some((major_str, rest)) = version_str.split_once('.') { + if let Some((minor_str, _)) = rest.split_once('.') { + if let (Ok(major), Ok(minor)) = + (major_str.parse::(), minor_str.parse::()) + { + println!( + "cargo:warning=Detected ROCm version {}.{} from hipcc", + major, minor + ); + return Some((major, minor)); + } + } + } + } + } + } + } + + println!("cargo:warning=Could not detect ROCm version"); + None +} + +/// Discover ROCm configuration including home, include dirs, and lib dirs +pub fn discover_rocm_config() -> Result { + let rocm_home = find_rocm_home().ok_or(BuildError::RocmNotFound)?; + let rocm_home_path = PathBuf::from(&rocm_home); + + let mut config = RocmConfig { + rocm_home: Some(rocm_home_path.clone()), + include_dirs: Vec::new(), + lib_dirs: Vec::new(), + }; + + // Add standard include directories + for include_subdir in &["include", "include/hip"] { + let include_dir = rocm_home_path.join(include_subdir); + if include_dir.exists() { + config.include_dirs.push(include_dir); + } + } + + // Add standard library directories + for lib_subdir in &["lib", "lib64"] { + let lib_dir = rocm_home_path.join(lib_subdir); + if lib_dir.exists() { + config.lib_dirs.push(lib_dir); + break; // Use first found + } + } + + Ok(config) +} + +/// Validate ROCm installation exists and is complete +pub fn validate_rocm_installation() -> Result { + let rocm_config = discover_rocm_config()?; + let rocm_home = rocm_config.rocm_home.ok_or(BuildError::RocmNotFound)?; + let rocm_home_str = rocm_home.to_string_lossy().to_string(); + + // Verify ROCm include directory exists + let rocm_include_path = rocm_home.join("include"); + if !rocm_include_path.exists() { + return Err(BuildError::PathNotFound(format!( + "ROCm include directory at {}", + rocm_include_path.display() + ))); + } + + Ok(rocm_home_str) +} + +/// Get ROCm library directory +pub fn get_rocm_lib_dir() -> Result { + // Check if user explicitly set ROCM_LIB_DIR + if let Ok(rocm_lib_dir) = env::var("ROCM_LIB_DIR") { + return Ok(rocm_lib_dir); + } + + // Try to deduce from ROCm configuration + let rocm_config = discover_rocm_config()?; + if let Some(rocm_home) = rocm_config.rocm_home { + // Check both lib and lib64 + for lib_subdir in &["lib", "lib64"] { + let lib_path = rocm_home.join(lib_subdir); + if lib_path.exists() { + return Ok(lib_path.to_string_lossy().to_string()); + } + } + } + + Err(BuildError::PathNotFound( + "ROCm library directory".to_string(), + )) +} + +/// Print helpful error message for ROCm not found +pub fn print_rocm_error_help() { + eprintln!("Error: ROCm installation not found!"); + eprintln!("Please ensure ROCm is installed and one of the following is true:"); + eprintln!(" 1. Set ROCM_HOME environment variable to your ROCm installation directory"); + eprintln!(" 2. Set ROCM_PATH environment variable to your ROCm installation directory"); + eprintln!(" 3. Ensure 'hipcc' is in your PATH"); + eprintln!(" 4. Install ROCm to the default location (/opt/rocm on Linux)"); + eprintln!(); + eprintln!("Example: export ROCM_HOME=/opt/rocm-6.4.2"); +} + +/// Print helpful error message for ROCm lib dir not found +pub fn print_rocm_lib_error_help() { + eprintln!("Error: ROCm library directory not found!"); + eprintln!("Please set ROCM_LIB_DIR environment variable to your ROCm library directory."); + eprintln!(); + eprintln!("Example: export ROCM_LIB_DIR=/opt/rocm/lib"); +} + #[cfg(test)] mod tests { use super::*; @@ -288,6 +525,14 @@ mod tests { assert_eq!(result, Some("/test/cuda".to_string())); } + #[test] + fn test_find_rocm_home_env_var() { + env::set_var("ROCM_HOME", "/test/rocm"); + let result = find_rocm_home(); + env::remove_var("ROCM_HOME"); + assert_eq!(result, Some("/test/rocm".to_string())); + } + #[test] fn test_python_scripts_constants() { assert!(PYTHON_PRINT_DIRS.contains("sysconfig")); diff --git a/cuda-sys/build.rs b/cuda-sys/build.rs index 22b8b4675..8a596d838 100644 --- a/cuda-sys/build.rs +++ b/cuda-sys/build.rs @@ -7,47 +7,229 @@ */ use std::env; +use std::fs; +use std::path::Path; use std::path::PathBuf; +use std::process::Command; + +// --- HIPify Helper Functions (cuda-sys specific) --- + +/// Applies the required 'CUstream_st' typedef fix to the hipified header. +fn patch_hipified_header(hipified_file_path: &Path) -> Result<(), Box> { + println!("cargo:warning=Patching hipified header for CUstream_st typedef..."); + + let hip_typedef = "\n// HIP/ROCm Fix: Manually define CUstream_st for cxx bindings\ntypedef struct ihipStream_t CUstream_st;\n"; + + let original_content = fs::read_to_string(hipified_file_path)?; + let lines: Vec<&str> = original_content.lines().collect(); + let mut insert_index = 0; + + for (i, line) in lines.iter().enumerate() { + if !line.trim().starts_with("#include") + && !line.trim().is_empty() + && !line.trim().starts_with("//") + { + insert_index = i; + break; + } + if i == lines.len() - 1 { + insert_index = lines.len(); + } + } + + let mut new_content = String::new(); + for (i, line) in lines.iter().enumerate() { + if i == insert_index { + new_content.push_str(hip_typedef); + } + new_content.push_str(line); + new_content.push('\n'); + } + + fs::write( + hipified_file_path, + new_content.trim_end_matches('\n').as_bytes(), + )?; + + println!("cargo:warning=Successfully injected CUstream_st typedef."); + Ok(()) +} + +/// Runs `hipify_torch` on the source file. +/// Returns the path to the newly hipified header file. +fn hipify_source_header( + python_interpreter: &Path, + src_dir: &Path, + hip_src_dir: &Path, + file_name: &str, +) -> Result> { + println!( + "cargo:warning=Copying source header {} to {} for in-place hipify...", + file_name, + hip_src_dir.display() + ); + fs::create_dir_all(hip_src_dir)?; + + let src_file = src_dir.join(file_name); + let dest_file = hip_src_dir.join(file_name); + + if src_file.exists() { + fs::copy(&src_file, &dest_file)?; + println!("cargo:rerun-if-changed={}", src_file.display()); + } else { + return Err(format!("Source file {} not found", src_file.display()).into()); + } + + println!("cargo:warning=Running hipify_torch in-place on copied sources with --v2..."); + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); + let project_root = manifest_dir + .parent() + .ok_or("Failed to find project root: manifest parent not found")?; + + let hipify_script = project_root + .join("deps") + .join("hipify_torch") + .join("hipify_cli.py"); + + if !hipify_script.exists() { + return Err(format!("hipify_cli.py not found at {}", hipify_script.display()).into()); + } + println!("cargo:rerun-if-changed={}", hipify_script.display()); + + let hipify_output = Command::new(python_interpreter) + .arg(&hipify_script) + .arg("--project-directory") + .arg(hip_src_dir) + .arg("--v2") + .output()?; + + if !hipify_output.status.success() { + return Err(format!( + "hipify_cli.py failed: {}", + String::from_utf8_lossy(&hipify_output.stderr) + ) + .into()); + } + + println!("cargo:warning=Successfully hipified {} source", file_name); + + // The hipified output file name is wrapper_hip.h + let hip_file = hip_src_dir.join("wrapper_hip.h"); + + if hip_file.exists() { + patch_hipified_header(&hip_file)?; + Ok(hip_file) + } else { + let fallback_file = hip_src_dir.join(file_name); + if fallback_file.exists() { + patch_hipified_header(&fallback_file)?; + Ok(fallback_file) + } else { + Err(format!( + "Hipified output file not found. Expected: {}", + hip_file.display() + ) + .into()) + } + } +} + +// --- Main Build Logic --- #[cfg(target_os = "macos")] fn main() {} #[cfg(not(target_os = "macos"))] fn main() { - // Discover CUDA configuration including include and lib directories - let cuda_config = match build_utils::discover_cuda_config() { - Ok(config) => config, - Err(_) => { - build_utils::print_cuda_error_help(); - std::process::exit(1); + const CUDA_HEADER_NAME: &str = "wrapper.h"; + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + + // Check if we are building for ROCm (HIP) - check ROCm first + let is_rocm = build_utils::find_rocm_home().is_some(); + + println!("cargo:rerun-if-env-changed=USE_ROCM"); + + let header_path; + let compute_lib_names; + let compute_config; + + if is_rocm { + println!("cargo:warning=Using HIP from ROCm installation"); + compute_lib_names = vec!["amdhip64"]; + + // HIPify the CUDA wrapper header + let hip_src_dir = out_dir.join("hipified_src"); + let python_interpreter = build_utils::find_python_interpreter(); + + header_path = hipify_source_header( + &python_interpreter, + &manifest_dir.join("src"), + &hip_src_dir, + CUDA_HEADER_NAME, + ) + .expect("Failed to hipify wrapper.h"); + + // Discover ROCm configuration + match build_utils::discover_rocm_config() { + Ok(config) => { + compute_config = build_utils::CudaConfig { + cuda_home: config.rocm_home, + include_dirs: config.include_dirs, + lib_dirs: config.lib_dirs, + } + } + Err(_) => { + build_utils::print_rocm_error_help(); + std::process::exit(1); + } } - }; + } else { + println!("cargo:warning=Using CUDA"); + compute_lib_names = vec!["cuda", "cudart"]; + header_path = manifest_dir.join("src").join(CUDA_HEADER_NAME); - // Start building the bindgen configuration + match build_utils::discover_cuda_config() { + Ok(config) => compute_config = config, + Err(_) => { + build_utils::print_cuda_error_help(); + std::process::exit(1); + } + } + } + + // Configure bindgen let mut builder = bindgen::Builder::default() - // The input header we would like to generate bindings for - .header("src/wrapper.h") + .header(header_path.to_str().expect("Invalid header path")) .clang_arg("-x") .clang_arg("c++") .clang_arg("-std=gnu++20") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - // Allow the specified functions and types (CUDA Runtime API only) .allowlist_function("cuda.*") .allowlist_function("CUDA.*") .allowlist_type("cuda.*") .allowlist_type("CUDA.*") - // Use newtype enum style + .allowlist_type("CUstream_st") + .allowlist_function("hip.*") + .allowlist_type("hip.*") .default_enum_style(bindgen::EnumVariation::NewType { is_bitfield: false, is_global: false, }); - // Add CUDA include paths from the discovered configuration - for include_dir in &cuda_config.include_dirs { + for include_dir in &compute_config.include_dirs { builder = builder.clang_arg(format!("-I{}", include_dir.display())); } - // Include headers and libs from the active environment. + if is_rocm { + builder = builder + .clang_arg("-D__HIP_PLATFORM_AMD__=1") + .clang_arg("-DUSE_ROCM=1"); + } + + // Python environment let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { Ok(config) => config, Err(_) => { @@ -63,39 +245,41 @@ fn main() { builder = builder.clang_arg(format!("-I{}", include_dir)); } if let Some(lib_dir) = &python_config.lib_dir { - println!("cargo::rustc-link-search=native={}", lib_dir); - // Set cargo metadata to inform dependent binaries about how to set their - // RPATH (see controller/build.rs for an example). - println!("cargo::metadata=LIB_PATH={}", lib_dir); + println!("cargo:rustc-link-search=native={}", lib_dir); + println!("cargo:metadata=LIB_PATH={}", lib_dir); } - // Get CUDA library directory and emit link directives - let cuda_lib_dir = match build_utils::get_cuda_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_cuda_lib_error_help(); - std::process::exit(1); + // Link compute libraries + let compute_lib_dir = if is_rocm { + match build_utils::get_rocm_lib_dir() { + Ok(dir) => dir, + Err(_) => { + build_utils::print_rocm_lib_error_help(); + std::process::exit(1); + } + } + } else { + match build_utils::get_cuda_lib_dir() { + Ok(dir) => dir, + Err(_) => { + build_utils::print_cuda_lib_error_help(); + std::process::exit(1); + } } }; - println!("cargo:rustc-link-search=native={}", cuda_lib_dir); - println!("cargo:rustc-link-lib=cudart"); + println!("cargo:rustc-link-search=native={}", compute_lib_dir); + for lib_name in compute_lib_names { + println!("cargo:rustc-link-lib={}", lib_name); + } - // Generate bindings - fail fast if this doesn't work + // Generate bindings let bindings = builder.generate().expect("Unable to generate bindings"); - // Write the bindings to the $OUT_DIR/bindings.rs file - match env::var("OUT_DIR") { - Ok(out_dir) => { - let out_path = PathBuf::from(out_dir); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings"); + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings"); - println!("cargo::rustc-cfg=cargo"); - println!("cargo::rustc-check-cfg=cfg(cargo)"); - } - Err(_) => { - println!("Note: OUT_DIR not set, skipping bindings file generation"); - } - } + println!("cargo:rustc-cfg=cargo"); + println!("cargo:rustc-check-cfg=cfg(cargo)"); } diff --git a/monarch_rdma/build.rs b/monarch_rdma/build.rs index d76dab156..dfdb25e53 100644 --- a/monarch_rdma/build.rs +++ b/monarch_rdma/build.rs @@ -11,14 +11,42 @@ fn main() {} #[cfg(not(target_os = "macos"))] fn main() { - // Validate CUDA installation and get CUDA home path - let _cuda_home = match build_utils::validate_cuda_installation() { - Ok(home) => home, - Err(_) => { - build_utils::print_cuda_error_help(); - std::process::exit(1); + // Check if we are building for ROCm (HIP) - check ROCm first + let is_rocm = build_utils::find_rocm_home().is_some(); + + // Validate compute installation and set cfg flags + if is_rocm { + match build_utils::validate_rocm_installation() { + Ok(_) => println!("cargo:warning=Using ROCm/HIP for monarch_rdma"), + Err(_) => { + build_utils::print_rocm_error_help(); + std::process::exit(1); + } } - }; + + // Set ROCm version cfg flags + let rocm_version = build_utils::find_rocm_home() + .and_then(|home| build_utils::get_rocm_version(&home)) + .unwrap_or((6, 0)); + + if rocm_version.0 >= 7 { + println!("cargo:rustc-cfg=rocm_7_plus"); + } else { + println!("cargo:rustc-cfg=rocm_6_x"); + } + } else { + match build_utils::validate_cuda_installation() { + Ok(_) => println!("cargo:warning=Using CUDA for monarch_rdma"), + Err(_) => { + build_utils::print_cuda_error_help(); + std::process::exit(1); + } + } + } + + // Emit cfg check declarations + println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); // Include headers and libs from the active environment. let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { @@ -39,17 +67,34 @@ fn main() { println!("cargo:metadata=LIB_PATH={}", lib_dir); } - // Get CUDA library directory and emit link directives - let cuda_lib_dir = match build_utils::get_cuda_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_cuda_lib_error_help(); - std::process::exit(1); + // Get compute library directory and emit link directives + let compute_lib_dir = if is_rocm { + match build_utils::get_rocm_lib_dir() { + Ok(dir) => dir, + Err(_) => { + build_utils::print_rocm_lib_error_help(); + std::process::exit(1); + } + } + } else { + match build_utils::get_cuda_lib_dir() { + Ok(dir) => dir, + Err(_) => { + build_utils::print_cuda_lib_error_help(); + std::process::exit(1); + } } }; - println!("cargo:rustc-link-search=native={}", cuda_lib_dir); - println!("cargo:rustc-link-lib=cuda"); - println!("cargo:rustc-link-lib=cudart"); + println!("cargo:rustc-link-search=native={}", compute_lib_dir); + + // Link compute libraries + if is_rocm { + println!("cargo:rustc-link-lib=amdhip64"); + println!("cargo:rustc-link-lib=hsa-runtime64"); + } else { + println!("cargo:rustc-link-lib=cuda"); + println!("cargo:rustc-link-lib=cudart"); + } // Link against the ibverbs and mlx5 libraries (used by rdmaxcel-sys) println!("cargo:rustc-link-lib=ibverbs"); @@ -72,7 +117,7 @@ fn main() { // Add library search path println!("cargo:rustc-link-search=native={}", path); // Set rpath so runtime linker can find the libraries - println!("cargo::rustc-link-arg=-Wl,-rpath,{}", path); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path); } } } @@ -82,21 +127,25 @@ fn main() { println!("cargo:rustc-link-lib=torch_cpu"); println!("cargo:rustc-link-lib=torch"); println!("cargo:rustc-link-lib=c10"); - println!("cargo:rustc-link-lib=c10_cuda"); + if is_rocm { + println!("cargo:rustc-link-lib=c10_hip"); + } else { + println!("cargo:rustc-link-lib=c10_cuda"); + } } else { // Fallback to torch-sys links metadata if available if let Ok(torch_lib_path) = std::env::var("DEP_TORCH_LIB_PATH") { - println!("cargo::rustc-link-arg=-Wl,-rpath,{}", torch_lib_path); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", torch_lib_path); } } // Set rpath for NCCL libraries if available if let Ok(nccl_lib_path) = std::env::var("DEP_NCCL_LIB_PATH") { - println!("cargo::rustc-link-arg=-Wl,-rpath,{}", nccl_lib_path); + println!("cargo:rustc-link-arg=-Wl,-rpath,{}", nccl_lib_path); } // Disable new dtags, as conda envs generally use `RPATH` over `RUNPATH` - println!("cargo::rustc-link-arg=-Wl,--disable-new-dtags"); + println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags"); // Link the static libraries from rdmaxcel-sys // Try the Cargo dependency mechanism first, then fall back to fixed paths @@ -144,6 +193,6 @@ fn main() { } // Set build configuration flags - println!("cargo::rustc-cfg=cargo"); - println!("cargo::rustc-check-cfg=cfg(cargo)"); + println!("cargo:rustc-cfg=cargo"); + println!("cargo:rustc-check-cfg=cfg(cargo)"); } diff --git a/monarch_rdma/src/test_utils.rs b/monarch_rdma/src/test_utils.rs index 2f0370d59..a1c141e23 100644 --- a/monarch_rdma/src/test_utils.rs +++ b/monarch_rdma/src/test_utils.rs @@ -111,8 +111,8 @@ pub mod test_utils { unsafe impl Send for SendSyncCudaContext {} unsafe impl Sync for SendSyncCudaContext {} - /// Actor responsible for CUDA initialization and buffer management within its own process context. - /// This is important because you preform CUDA operations within the same process as the RDMA operations. + /// Actor responsible for CUDA initialization and buffer management within its own process context. + /// This is important because you preform CUDA operations within the same process as the RDMA operations. #[hyperactor::export( spawn = true, handlers = [ @@ -200,7 +200,10 @@ pub mod test_utils { .device .ok_or_else(|| anyhow::anyhow!("Device not initialized"))?; - let (dptr, padded_size) = unsafe { + // Convert dptr to usize inside the unsafe block before the await + // This is important because hipDeviceptr_t is *mut c_void (not Send) + // while CUdeviceptr is an integer type (Send) + let (dptr_usize, padded_size) = unsafe { cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0)); let mut dptr: rdmaxcel_sys::CUdeviceptr = std::mem::zeroed(); @@ -213,8 +216,18 @@ pub mod test_utils { prop.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE; prop.location.id = device; prop.allocFlags.gpuDirectRDMACapable = 1; - prop.requestedHandleTypes = - rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + // HIP uses requestedHandleType (singular), CUDA uses requestedHandleTypes (plural) + #[cfg(any(rocm_6_x, rocm_7_plus))] + { + prop.requestedHandleType = + rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + } + #[cfg(not(any(rocm_6_x, rocm_7_plus)))] + { + prop.requestedHandleTypes = + rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + } cu_check!(rdmaxcel_sys::rdmaxcel_cuMemGetAllocationGranularity( &mut granularity as *mut usize, @@ -235,7 +248,7 @@ pub mod test_utils { &mut dptr, padded_size, 0, - 0, + std::ptr::null_mut(), 0, )); @@ -258,14 +271,14 @@ pub mod test_utils { 1 )); - (dptr, padded_size) + (dptr as usize, padded_size) }; let rdma_handle = rdma_actor - .request_buffer(cx, dptr as usize, padded_size) + .request_buffer(cx, dptr_usize, padded_size) .await?; - reply.send(cx, (rdma_handle, dptr as usize))?; + reply.send(cx, (rdma_handle, dptr_usize))?; Ok(()) } CudaActorMessage::FillBuffer { diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index 0a78d67de..046ea7658 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -7,138 +7,592 @@ */ use std::env; +use std::fs; use std::path::Path; use std::path::PathBuf; +use std::process::Command; + +// ============================================================================= +// Hipify Patching Functions (specific to rdmaxcel-sys) +// ============================================================================= + +/// Renames rdmaxcel_cu* wrapper functions to rdmaxcel_hip* in the given content +/// NOTE: rdmaxcel_cuMemGetHandleForAddressRange is intentionally NOT included here +/// because for ROCm 6.x we replace it with HSA function, and for ROCm 7+ we handle it separately +fn rename_rdmaxcel_wrappers(content: &str) -> String { + content + // Memory management wrappers + .replace("rdmaxcel_cuMemGetAllocationGranularity", "rdmaxcel_hipMemGetAllocationGranularity") + .replace("rdmaxcel_cuMemCreate", "rdmaxcel_hipMemCreate") + .replace("rdmaxcel_cuMemAddressReserve", "rdmaxcel_hipMemAddressReserve") + .replace("rdmaxcel_cuMemMap", "rdmaxcel_hipMemMap") + .replace("rdmaxcel_cuMemSetAccess", "rdmaxcel_hipMemSetAccess") + .replace("rdmaxcel_cuMemUnmap", "rdmaxcel_hipMemUnmap") + .replace("rdmaxcel_cuMemAddressFree", "rdmaxcel_hipMemAddressFree") + .replace("rdmaxcel_cuMemRelease", "rdmaxcel_hipMemRelease") + .replace("rdmaxcel_cuMemcpyHtoD_v2", "rdmaxcel_hipMemcpyHtoD") + .replace("rdmaxcel_cuMemcpyDtoH_v2", "rdmaxcel_hipMemcpyDtoH") + .replace("rdmaxcel_cuMemsetD8_v2", "rdmaxcel_hipMemsetD8") + // Pointer queries + .replace("rdmaxcel_cuPointerGetAttribute", "rdmaxcel_hipPointerGetAttribute") + // Device management + .replace("rdmaxcel_cuInit", "rdmaxcel_hipInit") + .replace("rdmaxcel_cuDeviceGetCount", "rdmaxcel_hipDeviceGetCount") + .replace("rdmaxcel_cuDeviceGetAttribute", "rdmaxcel_hipDeviceGetAttribute") + .replace("rdmaxcel_cuDeviceGet", "rdmaxcel_hipDeviceGet") + // Context management + .replace("rdmaxcel_cuCtxCreate_v2", "rdmaxcel_hipCtxCreate") + .replace("rdmaxcel_cuCtxSetCurrent", "rdmaxcel_hipCtxSetCurrent") + // Error handling + .replace("rdmaxcel_cuGetErrorString", "rdmaxcel_hipGetErrorString") +} + +/// Post-processes hipified files for ROCm 7.0+ +fn patch_hipified_files_rocm7(hip_src_dir: &Path) -> Result<(), Box> { + println!("cargo:warning=Patching hipify_torch output for ROCm 7.0+..."); + + // --- Patch rdmaxcel_hip.cpp --- + let cpp_file = hip_src_dir.join("rdmaxcel_hip.cpp"); + if cpp_file.exists() { + let content = fs::read_to_string(&cpp_file)?; + + let patched_content = content + .replace( + "#include ", + "#include \n#include ", + ) + .replace("c10::cuda::CUDACachingAllocator", "c10::hip::HIPCachingAllocator") + .replace("c10::cuda::CUDAAllocatorConfig", "c10::hip::HIPAllocatorConfig") + .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", "c10::hip::HIPCachingAllocator::HIPAllocatorConfig") + .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") + .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") + .replace("static_cast", "reinterpret_cast") + .replace("static_cast", "reinterpret_cast") + .replace("CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD", "hipMemRangeHandleTypeDmaBufFd") + .replace("cuMemGetHandleForAddressRange", "hipMemGetHandleForAddressRange") + .replace("CUDA_SUCCESS", "hipSuccess") + .replace("CUresult", "hipError_t"); + + fs::write(&cpp_file, patched_content)?; + } + + // --- Patch rdmaxcel_hip.h --- + let header_file = hip_src_dir.join("rdmaxcel_hip.h"); + if header_file.exists() { + let content = fs::read_to_string(&header_file)?; + let patched_content = content + .replace("#include \"driver_api.h\"", "#include \"driver_api_hip.h\"") + .replace("CUdeviceptr", "hipDeviceptr_t"); + + fs::write(&header_file, patched_content)?; + } + + // --- Patch driver_api_hip.h --- + let driver_api_h = hip_src_dir.join("driver_api_hip.h"); + if driver_api_h.exists() { + let content = fs::read_to_string(&driver_api_h)?; + let mut patched_content = rename_rdmaxcel_wrappers(&content); + + // For ROCm 7+, rename the dmabuf function and fix the type + patched_content = patched_content + .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") + .replace("CUmemRangeHandleType", "hipMemRangeHandleType"); + + fs::write(&driver_api_h, patched_content)?; + } + + // --- Patch driver_api_hip.cpp --- + let driver_api_cpp = hip_src_dir.join("driver_api_hip.cpp"); + if driver_api_cpp.exists() { + let content = fs::read_to_string(&driver_api_cpp)?; + + let mut patched_content = rename_rdmaxcel_wrappers(&content); + + patched_content = patched_content + // Fix library name + .replace("libcuda.so.1", "libamdhip64.so") + // Rename the dmabuf function + .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") + // Fix the macro entry hipify missed + .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)") + // Fix internal member references + .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") + .replace("->cuMemGetAllocationGranularity_(", "->hipMemGetAllocationGranularity_(") + .replace("->cuMemCreate_(", "->hipMemCreate_(") + .replace("->cuMemAddressReserve_(", "->hipMemAddressReserve_(") + .replace("->cuMemMap_(", "->hipMemMap_(") + .replace("->cuMemSetAccess_(", "->hipMemSetAccess_(") + .replace("->cuMemUnmap_(", "->hipMemUnmap_(") + .replace("->cuMemAddressFree_(", "->hipMemAddressFree_(") + .replace("->cuMemRelease_(", "->hipMemRelease_(") + .replace("->cuMemcpyHtoD_v2_(", "->hipMemcpyHtoD_(") + .replace("->cuMemcpyDtoH_v2_(", "->hipMemcpyDtoH_(") + .replace("->cuMemsetD8_v2_(", "->hipMemsetD8_(") + .replace("->cuPointerGetAttribute_(", "->hipPointerGetAttribute_(") + .replace("->cuInit_(", "->hipInit_(") + .replace("->cuDeviceGet_(", "->hipDeviceGet_(") + .replace("->cuDeviceGetCount_(", "->hipGetDeviceCount_(") + .replace("->cuDeviceGetAttribute_(", "->hipDeviceGetAttribute_(") + .replace("->cuCtxCreate_v2_(", "->hipCtxCreate_(") + .replace("->cuCtxSetCurrent_(", "->hipCtxSetCurrent_(") + .replace("->cuGetErrorString_(", "->hipDrvGetErrorString_(") + // Fix type + .replace("CUmemRangeHandleType", "hipMemRangeHandleType"); + + fs::write(&driver_api_cpp, patched_content)?; + } + + println!("cargo:warning=Applied ROCm 7.0+ post-processing fixes to hipified files"); + Ok(()) +} + +/// Post-processes files for ROCm 6.x (uses HSA dmabuf instead of HIP dmabuf) +fn patch_hipified_files_rocm6(hip_src_dir: &Path) -> Result<(), Box> { + println!("cargo:warning=Patching hipify_torch output for ROCm 6.x (HSA dmabuf)..."); + + // --- Patch rdmaxcel_hip.cpp --- + let cpp_file = hip_src_dir.join("rdmaxcel_hip.cpp"); + if cpp_file.exists() { + let content = fs::read_to_string(&cpp_file)?; + + let mut patched_content = content + // Add version and HSA headers at the top + .replace( + "#include ", + "#include \n#include \n#include \n#include " + ) + // Fix PyTorch allocator namespace: c10::cuda → c10::hip + .replace("c10::cuda::CUDACachingAllocator", "c10::hip::HIPCachingAllocator") + .replace("c10::cuda::CUDAAllocatorConfig", "c10::hip::HIPAllocatorConfig") + .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", "c10::hip::HIPCachingAllocator::HIPAllocatorConfig") + .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") + // Fix HIP API attribute names + .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") + // Fix pointer casts for HIP + .replace("static_cast", "reinterpret_cast") + .replace("static_cast", "reinterpret_cast") + // Replace CUDA types with HIP types + .replace("CUDA_SUCCESS", "hipSuccess") + .replace("CUdevice device", "hipDevice_t device") + // Fix device functions + .replace("cuDeviceGet(&device", "hipDeviceGet(&device") + .replace("cuDeviceGetAttribute", "hipDeviceGetAttribute") + .replace("cuPointerGetAttribute", "hipPointerGetAttribute") + // Fix device attribute constants + .replace("CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", "hipDeviceAttributePciBusId") + .replace("CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", "hipDeviceAttributePciDeviceId") + .replace("CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", "hipDeviceAttributePciDomainID") + .replace("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL") + // Remove CUDA-specific constants for dmabuf type + .replace("CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD", "0 /* HSA dmabuf */"); + + // Replace cuMemGetHandleForAddressRange calls with hsa_amd_portable_export_dmabuf + // Note: The parameter order is different: + // CUDA: cuMemGetHandleForAddressRange(&fd, ptr, size, type, flags) + // HSA: hsa_amd_portable_export_dmabuf(ptr, size, &fd, nullptr) + patched_content = patched_content.replace( + "cuMemGetHandleForAddressRange(", + "hsa_amd_portable_export_dmabuf(", + ); + + // Fix the parameter ordering for hsa_amd_portable_export_dmabuf calls + // Pattern for compact_mrs function + patched_content = patched_content.replace( + "hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(start_addr),\n total_size,\n 0 /* HSA dmabuf */,\n 0);", + "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(start_addr),\n total_size,\n &fd,\n nullptr);" + ); + + // Pattern for register_segments function + patched_content = patched_content.replace( + "hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(chunk_start),\n chunk_size,\n 0 /* HSA dmabuf */,\n 0);", + "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(chunk_start),\n chunk_size,\n &fd,\n nullptr);" + ); + + // Replace result types and checks for HSA + patched_content = patched_content + .replace("CUresult cu_result", "hsa_status_t hsa_result") + .replace("hipError_t cu_result", "hsa_status_t hsa_result") + .replace("cu_result != hipSuccess", "hsa_result != HSA_STATUS_SUCCESS") + .replace("if (cu_result", "if (hsa_result"); + + // Fix hipPointerGetAttribute enum usage + patched_content = patched_content.replace( + "hipPointerAttribute::device", + "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL", + ); + + fs::write(&cpp_file, patched_content)?; + } + + // --- Patch rdmaxcel_hip.h --- + let header_file = hip_src_dir.join("rdmaxcel_hip.h"); + if header_file.exists() { + let content = fs::read_to_string(&header_file)?; + let patched_content = content + .replace("#include \"driver_api.h\"", "#include \"driver_api_hip.h\"") + .replace("CUdeviceptr", "hipDeviceptr_t"); + + fs::write(&header_file, patched_content)?; + } + + // --- Patch driver_api_hip.h for ROCm 6.x --- + // Key change: Replace CUmemRangeHandleType-based function with HSA dmabuf function + let driver_api_h = hip_src_dir.join("driver_api_hip.h"); + if driver_api_h.exists() { + let content = fs::read_to_string(&driver_api_h)?; + + // First apply standard renames (but NOT the dmabuf function - we replace it entirely) + let mut patched_content = rename_rdmaxcel_wrappers(&content); + + // Add HSA header (but don't duplicate if already present) + if !patched_content.contains("#include ") { + patched_content = patched_content.replace( + "#include ", + "#include \n#include \n#include " + ); + } + + // Replace the CUmemRangeHandleType-based function declaration with HSA version + // The hipified file still has CUDA names (hipify doesn't convert custom rdmaxcel_ functions) + let old_decl = "hipError_t rdmaxcel_cuMemGetHandleForAddressRange(\n int* handle,\n hipDeviceptr_t dptr,\n size_t size,\n CUmemRangeHandleType handleType,\n unsigned long long flags);"; + let new_decl = "hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf(\n void* ptr,\n size_t size,\n int* fd,\n uint64_t* flags);"; + patched_content = patched_content.replace(old_decl, new_decl); + + // Also handle any remaining CUmemRangeHandleType references (shouldn't be any after above, but just in case) + patched_content = patched_content.replace("CUmemRangeHandleType", "int /* placeholder - ROCm 6.x */"); + + // Add CUDA-compatible wrapper declaration for monarch_rdma compatibility + patched_content.push_str("\n\n// CUDA-compatible wrapper for monarch_rdma\nhipError_t rdmaxcel_cuMemGetHandleForAddressRange(\n int* handle,\n hipDeviceptr_t dptr,\n size_t size,\n int handleType,\n unsigned long long flags);\n"); + + fs::write(&driver_api_h, patched_content)?; + } + + // --- Patch driver_api_hip.cpp for ROCm 6.x --- + let driver_api_cpp = hip_src_dir.join("driver_api_hip.cpp"); + if driver_api_cpp.exists() { + let content = fs::read_to_string(&driver_api_cpp)?; + + // Apply standard wrapper renames first + let mut patched_content = rename_rdmaxcel_wrappers(&content); + + // Add HSA headers + patched_content = patched_content.replace( + "#include \"driver_api_hip.h\"", + "#include \"driver_api_hip.h\"\n#include \n#include " + ); + + // Fix library name + patched_content = patched_content.replace("libcuda.so.1", "libamdhip64.so"); + + // Fix const void* to void* conversion for hipMemcpyHtoD (ROCm API difference) + patched_content = patched_content.replace( + "dstDevice, srcHost, ByteCount);", + "dstDevice, const_cast(srcHost), ByteCount);" + ); + + // Fix internal member references (hipify doesn't convert struct member names) + patched_content = patched_content + .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") + .replace("->cuMemGetAllocationGranularity_(", "->hipMemGetAllocationGranularity_(") + .replace("->cuMemCreate_(", "->hipMemCreate_(") + .replace("->cuMemAddressReserve_(", "->hipMemAddressReserve_(") + .replace("->cuMemMap_(", "->hipMemMap_(") + .replace("->cuMemSetAccess_(", "->hipMemSetAccess_(") + .replace("->cuMemUnmap_(", "->hipMemUnmap_(") + .replace("->cuMemAddressFree_(", "->hipMemAddressFree_(") + .replace("->cuMemRelease_(", "->hipMemRelease_(") + .replace("->cuMemcpyHtoD_v2_(", "->hipMemcpyHtoD_(") + .replace("->cuMemcpyDtoH_v2_(", "->hipMemcpyDtoH_(") + .replace("->cuMemsetD8_v2_(", "->hipMemsetD8_(") + .replace("->cuPointerGetAttribute_(", "->hipPointerGetAttribute_(") + .replace("->cuInit_(", "->hipInit_(") + .replace("->cuDeviceGet_(", "->hipDeviceGet_(") + .replace("->cuDeviceGetCount_(", "->hipGetDeviceCount_(") + .replace("->cuDeviceGetAttribute_(", "->hipDeviceGetAttribute_(") + .replace("->cuCtxCreate_v2_(", "->hipCtxCreate_(") + .replace("->cuCtxSetCurrent_(", "->hipCtxSetCurrent_(") + .replace("->cuGetErrorString_(", "->hipDrvGetErrorString_("); + + // Fix the macro entries hipify missed + patched_content = patched_content + .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)") + .replace("_(cuMemGetAllocationGranularity)", "_(hipMemGetAllocationGranularity)") + .replace("_(cuMemCreate)", "_(hipMemCreate)") + .replace("_(cuMemAddressReserve)", "_(hipMemAddressReserve)") + .replace("_(cuMemMap)", "_(hipMemMap)") + .replace("_(cuMemSetAccess)", "_(hipMemSetAccess)") + .replace("_(cuMemUnmap)", "_(hipMemUnmap)") + .replace("_(cuMemAddressFree)", "_(hipMemAddressFree)") + .replace("_(cuMemRelease)", "_(hipMemRelease)") + .replace("_(cuMemcpyHtoD_v2)", "_(hipMemcpyHtoD)") + .replace("_(cuMemcpyDtoH_v2)", "_(hipMemcpyDtoH)") + .replace("_(cuMemsetD8_v2)", "_(hipMemsetD8)") + .replace("_(cuPointerGetAttribute)", "_(hipPointerGetAttribute)") + .replace("_(cuInit)", "_(hipInit)") + .replace("_(cuDeviceGet)", "_(hipDeviceGet)") + .replace("_(cuDeviceGetCount)", "_(hipGetDeviceCount)") + .replace("_(cuDeviceGetAttribute)", "_(hipDeviceGetAttribute)") + .replace("_(cuCtxCreate_v2)", "_(hipCtxCreate)") + .replace("_(cuCtxSetCurrent)", "_(hipCtxSetCurrent)") + .replace("_(cuGetErrorString)", "_(hipDrvGetErrorString)"); + + // For ROCm 6.x: Replace the rdmaxcel_cuMemGetHandleForAddressRange wrapper function with HSA version + // The original implementation calls the DriverAPI member function, but for ROCm 6.x we call HSA directly + let old_wrapper = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + CUmemRangeHandleType handleType, + unsigned long long flags) { + return rdmaxcel::DriverAPI::get()->hipMemGetHandleForAddressRange_( + handle, dptr, size, handleType, flags); +}"#; + + let new_wrapper = r#"hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( + void* ptr, + size_t size, + int* fd, + uint64_t* flags) { + // Direct HSA call for ROCm 6.x - bypasses DriverAPI dynamic loading + return hsa_amd_portable_export_dmabuf(ptr, size, fd, flags); +}"#; + + patched_content = patched_content.replace(old_wrapper, new_wrapper); + + // Also try without the member function rename (in case replacement order matters) + let old_wrapper2 = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + CUmemRangeHandleType handleType, + unsigned long long flags) { + return rdmaxcel::DriverAPI::get()->cuMemGetHandleForAddressRange_( + handle, dptr, size, handleType, flags); +}"#; + patched_content = patched_content.replace(old_wrapper2, new_wrapper); + + // Handle any remaining CUmemRangeHandleType references + patched_content = patched_content.replace("CUmemRangeHandleType", "int /* placeholder - ROCm 6.x */"); + + // Remove hipMemGetHandleForAddressRange from the macro (we use direct HSA call instead) + // This prevents trying to dlsym a function that doesn't exist in ROCm 6.x + patched_content = patched_content.replace( + "_(hipMemGetHandleForAddressRange) \\", + "/* hipMemGetHandleForAddressRange removed for ROCm 6.x - using HSA */ \\" + ); + patched_content = patched_content.replace( + "_(hipMemGetHandleForAddressRange) \\", + "/* hipMemGetHandleForAddressRange removed for ROCm 6.x - using HSA */ \\" + ); + + // Add CUDA-compatible wrapper that monarch_rdma can call + let cuda_compat_wrapper = r#" + +// CUDA-compatible wrapper for monarch_rdma - translates to HSA call +hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + int handleType, + unsigned long long flags) { + (void)handleType; // unused - ROCm 6.x only supports dmabuf + (void)flags; // unused + hsa_status_t status = hsa_amd_portable_export_dmabuf( + reinterpret_cast(dptr), + size, + handle, + nullptr); + return (status == HSA_STATUS_SUCCESS) ? hipSuccess : hipErrorUnknown; +} +"#; + patched_content.push_str(cuda_compat_wrapper); + + fs::write(&driver_api_cpp, patched_content)?; + } + + println!("cargo:warning=Applied ROCm 6.x (HSA dmabuf) post-processing fixes to hipified files"); + Ok(()) +} + +/// Validates that hipified output files exist +fn validate_hipified_files(hip_src_dir: &Path) -> Result<(), Box> { + let required_files = [ + "rdmaxcel_hip.h", + "rdmaxcel_hip.c", + "rdmaxcel_hip.cpp", + "rdmaxcel.hip", + ]; + + for file_name in &required_files { + let file_path = hip_src_dir.join(file_name); + if !file_path.exists() { + return Err(format!( + "Required hipified file {} was not found in {}", + file_name, + hip_src_dir.display() + ) + .into()); + } + } + + Ok(()) +} + +/// Runs `hipify_torch` on the source directory. +fn hipify_sources( + python_interpreter: &Path, + src_dir: &Path, + hip_src_dir: &Path, + rocm_version: (u32, u32), +) -> Result<(), Box> { + println!( + "cargo:warning=Copying sources from {} to {} for in-place hipify...", + src_dir.display(), + hip_src_dir.display() + ); + fs::create_dir_all(hip_src_dir)?; + + // Include driver_api files for hipification + let files_to_copy = [ + "lib.rs", + "rdmaxcel.h", + "rdmaxcel.c", + "rdmaxcel.cpp", + "rdmaxcel.cu", + "test_rdmaxcel.c", + "driver_api.h", + "driver_api.cpp", + ]; + + for file_name in files_to_copy { + let src_file = src_dir.join(file_name); + let dest_file = hip_src_dir.join(file_name); + if src_file.exists() { + fs::copy(&src_file, &dest_file)?; + println!("cargo:rerun-if-changed={}", src_file.display()); + } + } + + println!("cargo:warning=Running hipify_torch in-place on copied sources with --v2..."); + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); + let project_root = manifest_dir + .parent() + .ok_or("Failed to find project root from manifest dir")?; + let hipify_script = project_root + .join("deps") + .join("hipify_torch") + .join("hipify_cli.py"); + + if !hipify_script.exists() { + return Err(format!("hipify_cli.py not found at {}", hipify_script.display()).into()); + } + println!("cargo:rerun-if-changed={}", hipify_script.display()); + + let hipify_output = Command::new(python_interpreter) + .arg(&hipify_script) + .arg("--project-directory") + .arg(hip_src_dir) + .arg("--v2") + .output()?; + + if !hipify_output.status.success() { + return Err(format!( + "hipify_cli.py failed: {}", + String::from_utf8_lossy(&hipify_output.stderr) + ) + .into()); + } + + // Apply version-specific patches + let (major, _minor) = rocm_version; + if major >= 7 { + patch_hipified_files_rocm7(hip_src_dir)?; + } else { + patch_hipified_files_rocm6(hip_src_dir)?; + } + + Ok(()) +} + +/// Gets libtorch include directories from PyTorch +fn get_libtorch_include_dirs(python_interpreter: &Path) -> Vec { + let mut include_dirs = Vec::new(); + + if let Ok(output) = Command::new(python_interpreter) + .arg("-c") + .arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS) + .output() + { + for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { + include_dirs.push(PathBuf::from(path)); + } + } + } + + include_dirs +} + +// ============================================================================= +// Main Build Logic +// ============================================================================= #[cfg(target_os = "macos")] fn main() {} #[cfg(not(target_os = "macos"))] fn main() { - // Link against the ibverbs library println!("cargo:rustc-link-lib=ibverbs"); - - // Link against the mlx5 library println!("cargo:rustc-link-lib=mlx5"); - // Link against dl for dynamic loading - println!("cargo:rustc-link-lib=dl"); - - // Tell cargo to invalidate the built crate whenever the wrapper changes - println!("cargo:rerun-if-changed=src/rdmaxcel.h"); - println!("cargo:rerun-if-changed=src/rdmaxcel.c"); - println!("cargo:rerun-if-changed=src/rdmaxcel.cpp"); - println!("cargo:rerun-if-changed=src/driver_api.h"); - println!("cargo:rerun-if-changed=src/driver_api.cpp"); + let (is_rocm, compute_home, compute_lib_names, rocm_version) = + if let Ok(rocm_home) = build_utils::validate_rocm_installation() { + let version = build_utils::get_rocm_version(&rocm_home).unwrap_or((6, 0)); + println!( + "cargo:warning=Using HIP/ROCm {} from {}", + format!("{}.{}", version.0, version.1), + rocm_home + ); + + if version.0 >= 7 { + println!("cargo:rustc-cfg=rocm_7_plus"); + } else { + println!("cargo:rustc-cfg=rocm_6_x"); + } - // Validate CUDA installation and get CUDA home path - let cuda_home = match build_utils::validate_cuda_installation() { - Ok(home) => home, - Err(_) => { + (true, rocm_home, vec!["amdhip64", "hsa-runtime64"], version) + } else if let Ok(cuda_home) = build_utils::validate_cuda_installation() { + println!("cargo:warning=Using CUDA from {}", cuda_home); + (false, cuda_home, vec!["cuda", "cudart"], (0, 0)) + } else { + eprintln!("Error: Neither CUDA nor ROCm installation found!"); build_utils::print_cuda_error_help(); + build_utils::print_rocm_error_help(); std::process::exit(1); - } - }; + }; + + // Emit cfg check declarations + println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); - // Get the directory of the current crate - let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| { - // For buck2 run, we know the package is in fbcode/monarch/rdmaxcel-sys - // Get the fbsource directory from the current directory path + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| { let current_dir = std::env::current_dir().expect("Failed to get current directory"); let current_path = current_dir.to_string_lossy(); - - // Find the fbsource part of the path if let Some(fbsource_pos) = current_path.find("fbsource") { let fbsource_path = ¤t_path[..fbsource_pos + "fbsource".len()]; format!("{}/fbcode/monarch/rdmaxcel-sys", fbsource_path) } else { - // If we can't find fbsource in the path, just use the current directory format!("{}/src", current_dir.to_string_lossy()) } - }); + })); + let src_dir = manifest_dir.join("src"); - // Create the absolute path to the header file - let header_path = format!("{}/src/rdmaxcel.h", manifest_dir); + let python_interpreter = build_utils::find_python_interpreter(); - // Check if the header file exists - if !Path::new(&header_path).exists() { - panic!("Header file not found at {}", header_path); - } + let compute_include_path = format!("{}/include", compute_home); + println!("cargo:rustc-env=CUDA_INCLUDE_PATH={}", compute_include_path); - // Start building the bindgen configuration - let mut builder = bindgen::Builder::default() - // The input header we would like to generate bindings for - .header(&header_path) - .clang_arg("-x") - .clang_arg("c++") - .clang_arg("-std=gnu++20") - .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - // Allow the specified functions, types, and variables - .allowlist_function("ibv_.*") - .allowlist_function("mlx5dv_.*") - .allowlist_function("mlx5_wqe_.*") - .allowlist_function("create_qp") - .allowlist_function("create_mlx5dv_.*") - .allowlist_function("register_cuda_memory") - .allowlist_function("db_ring") - .allowlist_function("cqe_poll") - .allowlist_function("send_wqe") - .allowlist_function("recv_wqe") - .allowlist_function("launch_db_ring") - .allowlist_function("launch_cqe_poll") - .allowlist_function("launch_send_wqe") - .allowlist_function("launch_recv_wqe") - .allowlist_function("rdma_get_active_segment_count") - .allowlist_function("rdma_get_all_segment_info") - .allowlist_function("pt_cuda_allocator_compatibility") - .allowlist_function("register_segments") - .allowlist_function("deregister_segments") - .allowlist_function("rdmaxcel_cu.*") - .allowlist_function("get_cuda_pci_address_from_ptr") - .allowlist_function("rdmaxcel_print_device_info") - .allowlist_function("rdmaxcel_error_string") - .allowlist_function("rdmaxcel_qp_.*") - .allowlist_function("poll_cq_with_cache") - .allowlist_function("completion_cache_.*") - .allowlist_type("ibv_.*") - .allowlist_type("mlx5dv_.*") - .allowlist_type("mlx5_wqe_.*") - .allowlist_type("cqe_poll_result_t") - .allowlist_type("wqe_params_t") - .allowlist_type("cqe_poll_params_t") - .allowlist_type("rdma_segment_info_t") - .allowlist_type("rdmaxcel_qp_t") - .allowlist_type("rdmaxcel_qp") - .allowlist_type("completion_cache_t") - .allowlist_type("completion_cache") - .allowlist_type("poll_context_t") - .allowlist_type("poll_context") - .allowlist_var("MLX5_.*") - .allowlist_var("IBV_.*") - // Block specific types that are manually defined in lib.rs - .blocklist_type("ibv_wc") - .blocklist_type("mlx5_wqe_ctrl_seg") - // Apply the same bindgen flags as in the BUCK file - .bitfield_enum("ibv_access_flags") - .bitfield_enum("ibv_qp_attr_mask") - .bitfield_enum("ibv_wc_flags") - .bitfield_enum("ibv_send_flags") - .bitfield_enum("ibv_port_cap_flags") - .constified_enum_module("ibv_qp_type") - .constified_enum_module("ibv_qp_state") - .constified_enum_module("ibv_port_state") - .constified_enum_module("ibv_wc_opcode") - .constified_enum_module("ibv_wr_opcode") - .constified_enum_module("ibv_wc_status") - .derive_default(true) - .prepend_enum_name(false); - - // Add CUDA include path (we already validated it exists) - let cuda_include_path = format!("{}/include", cuda_home); - println!("cargo:rustc-env=CUDA_INCLUDE_PATH={}", cuda_include_path); - builder = builder.clang_arg(format!("-I{}", cuda_include_path)); - - // Include headers and libs from the active environment. let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { Ok(config) => config, Err(_) => { @@ -150,34 +604,44 @@ fn main() { } }; - if let Some(include_dir) = &python_config.include_dir { - builder = builder.clang_arg(format!("-I{}", include_dir)); - } - if let Some(lib_dir) = &python_config.lib_dir { - println!("cargo:rustc-link-search=native={}", lib_dir); - println!("cargo:metadata=LIB_PATH={}", lib_dir); - } - - // Get CUDA library directory and emit link directives - let cuda_lib_dir = match build_utils::get_cuda_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_cuda_lib_error_help(); - std::process::exit(1); + let compute_lib_dir = if is_rocm { + match build_utils::get_rocm_lib_dir() { + Ok(dir) => dir, + Err(_) => { + build_utils::print_rocm_lib_error_help(); + std::process::exit(1); + } + } + } else { + match build_utils::get_cuda_lib_dir() { + Ok(dir) => dir, + Err(_) => { + build_utils::print_cuda_lib_error_help(); + std::process::exit(1); + } } }; - println!("cargo:rustc-link-search=native={}", cuda_lib_dir); - // Note: libcuda is now loaded dynamically via dlopen in driver_api.cpp - // Only link cudart (CUDA Runtime API) - println!("cargo:rustc-link-lib=cudart"); + println!("cargo:rustc-link-search=native={}", compute_lib_dir); + for lib_name in &compute_lib_names { + println!("cargo:rustc-link-lib={}", lib_name); + } - // Link PyTorch C++ libraries for c10 symbols let use_pytorch_apis = build_utils::get_env_var_with_rerun("TORCH_SYS_USE_PYTORCH_APIS") .unwrap_or_else(|_| "1".to_owned()); + + let libtorch_include_dirs: Vec = if use_pytorch_apis == "1" { + get_libtorch_include_dirs(&python_interpreter) + } else { + build_utils::get_env_var_with_rerun("LIBTORCH_INCLUDE") + .unwrap_or_default() + .split(':') + .filter(|s| !s.is_empty()) + .map(PathBuf::from) + .collect() + }; + if use_pytorch_apis == "1" { - // Try to get PyTorch library directory - let python_interpreter = std::path::PathBuf::from("python"); - if let Ok(output) = std::process::Command::new(&python_interpreter) + if let Ok(output) = Command::new(&python_interpreter) .arg("-c") .arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS) .output() @@ -189,170 +653,308 @@ fn main() { } } } - // Link core PyTorch libraries needed for C10 symbols println!("cargo:rustc-link-lib=torch_cpu"); println!("cargo:rustc-link-lib=torch"); println!("cargo:rustc-link-lib=c10"); + if is_rocm { + println!("cargo:rustc-link-lib=c10_hip"); + } else { + println!("cargo:rustc-link-lib=c10_cuda"); + } } - // Generate bindings - let bindings = builder.generate().expect("Unable to generate bindings"); + // Link dl for dynamic loading + println!("cargo:rustc-link-lib=dl"); - // Write the bindings to the $OUT_DIR/bindings.rs file match env::var("OUT_DIR") { Ok(out_dir) => { - // Export OUT_DIR so dependent crates can find our compiled libraries - println!("cargo:out_dir={}", out_dir); - - let out_path = PathBuf::from(&out_dir); - match bindings.write_to_file(out_path.join("bindings.rs")) { - Ok(_) => { - println!("cargo:rustc-cfg=cargo"); - println!("cargo:rustc-check-cfg=cfg(cargo)"); - } - Err(e) => eprintln!("Warning: Couldn't write bindings: {}", e), - } + let out_path = PathBuf::from(out_dir); + println!("cargo:out_dir={}", out_path.display()); - // Compile the C source file - let c_source_path = format!("{}/src/rdmaxcel.c", manifest_dir); - if Path::new(&c_source_path).exists() { - let mut build = cc::Build::new(); - build - .file(&c_source_path) - .include(format!("{}/src", manifest_dir)) - .flag("-fPIC"); + let (code_dir, header_path, c_source_path, cpp_source_path, cuda_source_path, driver_api_cpp_path); - // Add CUDA include paths - reuse the paths we already found for bindgen - build.include(&cuda_include_path); + if is_rocm { + let hip_src_dir = out_path.join("hipified_src"); - build.compile("rdmaxcel"); + hipify_sources(&python_interpreter, &src_dir, &hip_src_dir, rocm_version) + .expect("Failed to hipify sources"); + + validate_hipified_files(&hip_src_dir).expect("Hipified files validation failed"); + + code_dir = hip_src_dir.clone(); + header_path = hip_src_dir.join("rdmaxcel_hip.h"); + c_source_path = hip_src_dir.join("rdmaxcel_hip.c"); + cpp_source_path = hip_src_dir.join("rdmaxcel_hip.cpp"); + cuda_source_path = hip_src_dir.join("rdmaxcel.hip"); + driver_api_cpp_path = hip_src_dir.join("driver_api_hip.cpp"); } else { - panic!("C source file not found at {}", c_source_path); + println!("cargo:rerun-if-changed={}/src/rdmaxcel.h", manifest_dir.display()); + println!("cargo:rerun-if-changed={}/src/rdmaxcel.c", manifest_dir.display()); + println!("cargo:rerun-if-changed={}/src/rdmaxcel.cpp", manifest_dir.display()); + println!("cargo:rerun-if-changed={}/src/rdmaxcel.cu", manifest_dir.display()); + println!("cargo:rerun-if-changed={}/src/driver_api.h", manifest_dir.display()); + println!("cargo:rerun-if-changed={}/src/driver_api.cpp", manifest_dir.display()); + + code_dir = src_dir.clone(); + header_path = src_dir.join("rdmaxcel.h"); + c_source_path = src_dir.join("rdmaxcel.c"); + cpp_source_path = src_dir.join("rdmaxcel.cpp"); + cuda_source_path = src_dir.join("rdmaxcel.cu"); + driver_api_cpp_path = src_dir.join("driver_api.cpp"); } - // Compile the C++ source file for CUDA allocator compatibility - let cpp_source_path = format!("{}/src/rdmaxcel.cpp", manifest_dir); - let driver_api_cpp_path = format!("{}/src/driver_api.cpp", manifest_dir); - if Path::new(&cpp_source_path).exists() && Path::new(&driver_api_cpp_path).exists() { - let mut libtorch_include_dirs: Vec = vec![]; - - // Use the same approach as torch-sys: Python discovery first, env vars as fallback - let use_pytorch_apis = - build_utils::get_env_var_with_rerun("TORCH_SYS_USE_PYTORCH_APIS") - .unwrap_or_else(|_| "1".to_owned()); - - if use_pytorch_apis == "1" { - // Use Python to get PyTorch include paths (same as torch-sys) - let python_interpreter = PathBuf::from("python"); - let output = std::process::Command::new(&python_interpreter) - .arg("-c") - .arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS) - .output() - .unwrap_or_else(|_| panic!("error running {python_interpreter:?}")); + if !header_path.exists() { + panic!("Header file not found at {}", header_path.display()); + } - for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { - libtorch_include_dirs.push(PathBuf::from(path)); - } - } + let mut builder = bindgen::Builder::default() + .header(header_path.to_string_lossy()) + .clang_arg("-x") + .clang_arg("c++") + .clang_arg("-std=gnu++20") + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .allowlist_function("ibv_.*") + .allowlist_function("mlx5dv_.*") + .allowlist_function("mlx5_wqe_.*") + .allowlist_function("create_qp") + .allowlist_function("create_mlx5dv_.*") + .allowlist_function("register_cuda_memory") + .allowlist_function("register_hip_memory") + .allowlist_function("db_ring") + .allowlist_function("cqe_poll") + .allowlist_function("send_wqe") + .allowlist_function("recv_wqe") + .allowlist_function("launch_db_ring") + .allowlist_function("launch_cqe_poll") + .allowlist_function("launch_send_wqe") + .allowlist_function("launch_recv_wqe") + .allowlist_function("rdma_get_active_segment_count") + .allowlist_function("rdma_get_all_segment_info") + .allowlist_function("pt_cuda_allocator_compatibility") + .allowlist_function("pt_hip_allocator_compatibility") + .allowlist_function("register_segments") + .allowlist_function("deregister_segments") + .allowlist_function("register_dmabuf_buffer") + .allowlist_function("get_hip_pci_address_from_ptr") + // Driver API wrappers (CUDA/HIP/HSA) + .allowlist_function("rdmaxcel_cu.*") + .allowlist_function("rdmaxcel_hip.*") + .allowlist_function("rdmaxcel_hsa.*") + // QP management functions + .allowlist_function("rdmaxcel_qp_.*") + .allowlist_function("rdmaxcel_print_device_info") + .allowlist_function("rdmaxcel_error_string") + // Completion cache functions + .allowlist_function("completion_cache_.*") + .allowlist_function("poll_cq_with_cache") + // Types for QP and completion handling + .allowlist_type("rdmaxcel_qp_t") + .allowlist_type("rdmaxcel_qp") + .allowlist_type("rdmaxcel_error_code_t") + .allowlist_type("completion_cache_t") + .allowlist_type("completion_cache") + .allowlist_type("completion_node_t") + .allowlist_type("completion_node") + .allowlist_type("poll_context_t") + .allowlist_type("poll_context") + .allowlist_type("rdma_qp_type_t") + // CUDA types (for CUDA builds) + .allowlist_type("CUdeviceptr") + .allowlist_type("CUdevice") + .allowlist_type("CUresult") + .allowlist_type("CUcontext") + .allowlist_type("CUmemRangeHandleType") + .allowlist_var("CUDA_SUCCESS") + .allowlist_var("CU_.*") + // HIP types (for ROCm builds) + .allowlist_type("hipDeviceptr_t") + .allowlist_type("hipDevice_t") + .allowlist_type("hipError_t") + .allowlist_type("hipCtx_t") + .allowlist_type("hipPointer_attribute") + .allowlist_var("hipSuccess") + .allowlist_var("HIP_.*") + // HSA types (for ROCm 6.x dmabuf) + .allowlist_type("hsa_status_t") + .allowlist_var("HSA_STATUS_SUCCESS") + .allowlist_type("ibv_.*") + .allowlist_type("mlx5dv_.*") + .allowlist_type("mlx5_wqe_.*") + .allowlist_type("cqe_poll_result_t") + .allowlist_type("wqe_params_t") + .allowlist_type("cqe_poll_params_t") + .allowlist_type("rdma_segment_info_t") + .allowlist_var("MLX5_.*") + .allowlist_var("IBV_.*") + .allowlist_var("RDMA_QP_TYPE_.*") + .blocklist_type("ibv_wc") + .blocklist_type("mlx5_wqe_ctrl_seg") + .bitfield_enum("ibv_access_flags") + .bitfield_enum("ibv_qp_attr_mask") + .bitfield_enum("ibv_wc_flags") + .bitfield_enum("ibv_send_flags") + .bitfield_enum("ibv_port_cap_flags") + .constified_enum_module("ibv_qp_type") + .constified_enum_module("ibv_qp_state") + .constified_enum_module("ibv_port_state") + .constified_enum_module("ibv_wc_opcode") + .constified_enum_module("ibv_wr_opcode") + .constified_enum_module("ibv_wc_status") + .derive_default(true) + .prepend_enum_name(false); + + builder = builder.clang_arg(format!("-I{}", compute_include_path)); + + // NOTE: Do NOT add libtorch include paths to bindgen + + if is_rocm { + builder = builder + .clang_arg("-D__HIP_PLATFORM_AMD__=1") + .clang_arg("-DUSE_ROCM=1"); + + if rocm_version.0 >= 7 { + builder = builder.clang_arg("-DROCM_7_PLUS=1"); } else { - // Use environment variables (fallback approach) - libtorch_include_dirs.extend( - build_utils::get_env_var_with_rerun("LIBTORCH_INCLUDE") - .unwrap_or_default() - .split(':') - .filter(|s| !s.is_empty()) - .map(PathBuf::from), - ); + builder = builder.clang_arg("-DROCM_6_X=1"); } + } + + if let Some(include_dir) = &python_config.include_dir { + builder = builder.clang_arg(format!("-I{}", include_dir)); + } + + let bindings = builder.generate().expect("Unable to generate bindings"); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings"); + + println!("cargo:rustc-cfg=cargo"); + println!("cargo:rustc-check-cfg=cfg(cargo)"); + + if c_source_path.exists() { + let mut build = cc::Build::new(); + build.file(&c_source_path).include(&code_dir).flag("-fPIC"); + build.include(&compute_include_path); + if is_rocm { + build.define("__HIP_PLATFORM_AMD__", "1"); + build.define("USE_ROCM", "1"); + if rocm_version.0 >= 7 { + build.define("ROCM_7_PLUS", "1"); + } else { + build.define("ROCM_6_X", "1"); + } + } + build.compile("rdmaxcel"); + } + if cpp_source_path.exists() { let mut cpp_build = cc::Build::new(); cpp_build .file(&cpp_source_path) - .file(&driver_api_cpp_path) - .include(format!("{}/src", manifest_dir)) + .include(&code_dir) .flag("-fPIC") .cpp(true) .flag("-std=gnu++20") + .flag("-Wno-unused-parameter") .define("PYTORCH_C10_DRIVER_API_SUPPORTED", "1"); - // Add CUDA include paths - cpp_build.include(&cuda_include_path); + if driver_api_cpp_path.exists() { + cpp_build.file(&driver_api_cpp_path); + } - // Add PyTorch/C10 include paths + cpp_build.include(&compute_include_path); + if is_rocm { + cpp_build.define("__HIP_PLATFORM_AMD__", "1"); + cpp_build.define("USE_ROCM", "1"); + if rocm_version.0 >= 7 { + cpp_build.define("ROCM_7_PLUS", "1"); + } else { + cpp_build.define("ROCM_6_X", "1"); + } + } for include_dir in &libtorch_include_dirs { cpp_build.include(include_dir); } - - // Add Python include path if available if let Some(include_dir) = &python_config.include_dir { cpp_build.include(include_dir); } - cpp_build.compile("rdmaxcel_cpp"); - } else { - if !Path::new(&cpp_source_path).exists() { - panic!("C++ source file not found at {}", cpp_source_path); - } - if !Path::new(&driver_api_cpp_path).exists() { - panic!( - "Driver API C++ source file not found at {}", - driver_api_cpp_path - ); - } } - // Compile the CUDA source file - let cuda_source_path = format!("{}/src/rdmaxcel.cu", manifest_dir); - if Path::new(&cuda_source_path).exists() { - // Use the CUDA home path we already validated - let nvcc_path = format!("{}/bin/nvcc", cuda_home); - - // Set up fixed output directory - use a predictable path instead of dynamic OUT_DIR - let cuda_build_dir = format!("{}/target/cuda_build", manifest_dir); + + if cuda_source_path.exists() { + let (compiler_path, compiler_name) = if is_rocm { + (format!("{}/bin/hipcc", compute_home), "hipcc") + } else { + (format!("{}/bin/nvcc", compute_home), "nvcc") + }; + + let cuda_build_dir = format!("{}/target/cuda_build", manifest_dir.display()); std::fs::create_dir_all(&cuda_build_dir) .expect("Failed to create CUDA build directory"); - let cuda_obj_path = format!("{}/rdmaxcel_cuda.o", cuda_build_dir); let cuda_lib_path = format!("{}/librdmaxcel_cuda.a", cuda_build_dir); - // Use nvcc to compile the CUDA file - let nvcc_output = std::process::Command::new(&nvcc_path) - .args(&[ + let compiler_output = if is_rocm { + let mut cmd = Command::new(&compiler_path); + cmd.args([ "-c", - &cuda_source_path, + cuda_source_path.to_str().unwrap(), "-o", &cuda_obj_path, - "--compiler-options", "-fPIC", "-std=c++20", - "--expt-extended-lambda", - "-Xcompiler", - "-fPIC", - &format!("-I{}", cuda_include_path), - &format!("-I{}/src", manifest_dir), - &format!("-I/usr/include"), - &format!("-I/usr/include/infiniband"), - ]) - .output(); + "-D__HIP_PLATFORM_AMD__=1", + "-DUSE_ROCM=1", + &format!("-I{}", compute_include_path), + &format!("-I{}", code_dir.display()), + "-I/usr/include", + "-I/usr/include/infiniband", + ]); + + if rocm_version.0 >= 7 { + cmd.arg("-DROCM_7_PLUS=1"); + } else { + cmd.arg("-DROCM_6_X=1"); + } - match nvcc_output { + cmd.output() + } else { + Command::new(&compiler_path) + .args([ + "-c", + cuda_source_path.to_str().unwrap(), + "-o", + &cuda_obj_path, + "--compiler-options", + "-fPIC", + "-std=c++20", + "--expt-extended-lambda", + "-Xcompiler", + "-fPIC", + &format!("-I{}", compute_include_path), + &format!("-I{}", code_dir.display()), + "-I/usr/include", + "-I/usr/include/infiniband", + ]) + .output() + }; + + match compiler_output { Ok(output) => { if !output.status.success() { - eprintln!("nvcc stderr: {}", String::from_utf8_lossy(&output.stderr)); - eprintln!("nvcc stdout: {}", String::from_utf8_lossy(&output.stdout)); - panic!("Failed to compile CUDA source with nvcc"); + eprintln!("{} stderr: {}", compiler_name, String::from_utf8_lossy(&output.stderr)); + eprintln!("{} stdout: {}", compiler_name, String::from_utf8_lossy(&output.stdout)); + panic!("Failed to compile CUDA/HIP source with {}", compiler_name); } - println!("cargo:rerun-if-changed={}", cuda_source_path); } Err(e) => { - eprintln!("Failed to run nvcc: {}", e); - panic!("nvcc not found or failed to execute"); + eprintln!("Failed to run {}: {}", compiler_name, e); + panic!("{} not found or failed to execute", compiler_name); } } - // Create static library from the compiled CUDA object - let ar_output = std::process::Command::new("ar") - .args(&["rcs", &cuda_lib_path, &cuda_obj_path]) + let ar_output = Command::new("ar") + .args(["rcs", &cuda_lib_path, &cuda_obj_path]) .output(); match ar_output { @@ -361,14 +963,12 @@ fn main() { eprintln!("ar stderr: {}", String::from_utf8_lossy(&output.stderr)); panic!("Failed to create CUDA static library with ar"); } - // Emit metadata so dependent crates can find this library println!("cargo:rustc-link-lib=static=rdmaxcel_cuda"); println!("cargo:rustc-link-search=native={}", cuda_build_dir); - - // Copy the library to OUT_DIR as well for Cargo dependency mechanism - if let Err(e) = - std::fs::copy(&cuda_lib_path, format!("{}/librdmaxcel_cuda.a", out_dir)) - { + if let Err(e) = std::fs::copy( + &cuda_lib_path, + format!("{}/librdmaxcel_cuda.a", out_path.display()), + ) { eprintln!("Warning: Failed to copy CUDA library to OUT_DIR: {}", e); } } @@ -377,8 +977,6 @@ fn main() { panic!("ar not found or failed to execute"); } } - } else { - panic!("CUDA source file not found at {}", cuda_source_path); } } Err(_) => { diff --git a/rdmaxcel-sys/src/lib.rs b/rdmaxcel-sys/src/lib.rs index 546fd84ad..e199ed1ef 100644 --- a/rdmaxcel-sys/src/lib.rs +++ b/rdmaxcel-sys/src/lib.rs @@ -88,102 +88,16 @@ mod inner { } /// Returns the number of bytes transferred. - /// - /// Relevant if the Receive Queue for incoming Send or RDMA Write with immediate operations. - /// This value doesn't include the length of the immediate data, if such exists. Relevant in - /// the Send Queue for RDMA Read and Atomic operations. - /// - /// For the Receive Queue of a UD QP that is not associated with an SRQ or for an SRQ that is - /// associated with a UD QP this value equals to the payload of the message plus the 40 bytes - /// reserved for the GRH. The number of bytes transferred is the payload of the message plus - /// the 40 bytes reserved for the GRH, whether or not the GRH is present pub fn len(&self) -> usize { self.byte_len as usize } /// Check if this work requested completed successfully. - /// - /// A successful work completion (`IBV_WC_SUCCESS`) means that the corresponding Work Request - /// (and all of the unsignaled Work Requests that were posted previous to it) ended, and the - /// memory buffers that this Work Request refers to are ready to be (re)used. pub fn is_valid(&self) -> bool { self.status == ibv_wc_status::IBV_WC_SUCCESS } - /// Returns the work completion status and vendor error syndrome (`vendor_err`) if the work - /// request did not completed successfully. - /// - /// Possible statuses include: - /// - /// - `IBV_WC_LOC_LEN_ERR`: Local Length Error: this happens if a Work Request that was posted - /// in a local Send Queue contains a message that is greater than the maximum message size - /// that is supported by the RDMA device port that should send the message or an Atomic - /// operation which its size is different than 8 bytes was sent. This also may happen if a - /// Work Request that was posted in a local Receive Queue isn't big enough for holding the - /// incoming message or if the incoming message size if greater the maximum message size - /// supported by the RDMA device port that received the message. - /// - `IBV_WC_LOC_QP_OP_ERR`: Local QP Operation Error: an internal QP consistency error was - /// detected while processing this Work Request: this happens if a Work Request that was - /// posted in a local Send Queue of a UD QP contains an Address Handle that is associated - /// with a Protection Domain to a QP which is associated with a different Protection Domain - /// or an opcode which isn't supported by the transport type of the QP isn't supported (for - /// example: - /// RDMA Write over a UD QP). - /// - `IBV_WC_LOC_EEC_OP_ERR`: Local EE Context Operation Error: an internal EE Context - /// consistency error was detected while processing this Work Request (unused, since its - /// relevant only to RD QPs or EE Context, which aren’t supported). - /// - `IBV_WC_LOC_PROT_ERR`: Local Protection Error: the locally posted Work Request’s buffers - /// in the scatter/gather list does not reference a Memory Region that is valid for the - /// requested operation. - /// - `IBV_WC_WR_FLUSH_ERR`: Work Request Flushed Error: A Work Request was in process or - /// outstanding when the QP transitioned into the Error State. - /// - `IBV_WC_MW_BIND_ERR`: Memory Window Binding Error: A failure happened when tried to bind - /// a MW to a MR. - /// - `IBV_WC_BAD_RESP_ERR`: Bad Response Error: an unexpected transport layer opcode was - /// returned by the responder. Relevant for RC QPs. - /// - `IBV_WC_LOC_ACCESS_ERR`: Local Access Error: a protection error occurred on a local data - /// buffer during the processing of a RDMA Write with Immediate operation sent from the - /// remote node. Relevant for RC QPs. - /// - `IBV_WC_REM_INV_REQ_ERR`: Remote Invalid Request Error: The responder detected an - /// invalid message on the channel. Possible causes include the operation is not supported - /// by this receive queue (qp_access_flags in remote QP wasn't configured to support this - /// operation), insufficient buffering to receive a new RDMA or Atomic Operation request, or - /// the length specified in a RDMA request is greater than 2^{31} bytes. Relevant for RC - /// QPs. - /// - `IBV_WC_REM_ACCESS_ERR`: Remote Access Error: a protection error occurred on a remote - /// data buffer to be read by an RDMA Read, written by an RDMA Write or accessed by an - /// atomic operation. This error is reported only on RDMA operations or atomic operations. - /// Relevant for RC QPs. - /// - `IBV_WC_REM_OP_ERR`: Remote Operation Error: the operation could not be completed - /// successfully by the responder. Possible causes include a responder QP related error that - /// prevented the responder from completing the request or a malformed WQE on the Receive - /// Queue. Relevant for RC QPs. - /// - `IBV_WC_RETRY_EXC_ERR`: Transport Retry Counter Exceeded: The local transport timeout - /// retry counter was exceeded while trying to send this message. This means that the remote - /// side didn't send any Ack or Nack. If this happens when sending the first message, - /// usually this mean that the connection attributes are wrong or the remote side isn't in a - /// state that it can respond to messages. If this happens after sending the first message, - /// usually it means that the remote QP isn't available anymore. Relevant for RC QPs. - /// - `IBV_WC_RNR_RETRY_EXC_ERR`: RNR Retry Counter Exceeded: The RNR NAK retry count was - /// exceeded. This usually means that the remote side didn't post any WR to its Receive - /// Queue. Relevant for RC QPs. - /// - `IBV_WC_LOC_RDD_VIOL_ERR`: Local RDD Violation Error: The RDD associated with the QP - /// does not match the RDD associated with the EE Context (unused, since its relevant only - /// to RD QPs or EE Context, which aren't supported). - /// - `IBV_WC_REM_INV_RD_REQ_ERR`: Remote Invalid RD Request: The responder detected an - /// invalid incoming RD message. Causes include a Q_Key or RDD violation (unused, since its - /// relevant only to RD QPs or EE Context, which aren't supported) - /// - `IBV_WC_REM_ABORT_ERR`: Remote Aborted Error: For UD or UC QPs associated with a SRQ, - /// the responder aborted the operation. - /// - `IBV_WC_INV_EECN_ERR`: Invalid EE Context Number: An invalid EE Context number was - /// detected (unused, since its relevant only to RD QPs or EE Context, which aren't - /// supported). - /// - `IBV_WC_INV_EEC_STATE_ERR`: Invalid EE Context State Error: Operation is not legal for - /// the specified EE Context state (unused, since its relevant only to RD QPs or EE Context, - /// which aren't supported). - /// - `IBV_WC_FATAL_ERR`: Fatal Error. - /// - `IBV_WC_RESP_TIMEOUT_ERR`: Response Timeout Error. - /// - `IBV_WC_GENERAL_ERR`: General Error: other error which isn't one of the above errors. + /// Returns the work completion status and vendor error syndrome if failed. pub fn error(&self) -> Option<(ibv_wc_status::Type, u32)> { match self.status { ibv_wc_status::IBV_WC_SUCCESS => None, @@ -192,20 +106,11 @@ mod inner { } /// Returns the operation that the corresponding Work Request performed. - /// - /// This value controls the way that data was sent, the direction of the data flow and the - /// valid attributes in the Work Completion. pub fn opcode(&self) -> ibv_wc_opcode::Type { self.opcode } - /// Returns a 32 bits number, in network order, in an SEND or RDMA WRITE opcodes that is being - /// sent along with the payload to the remote side and placed in a Receive Work Completion and - /// not in a remote memory buffer - /// - /// Note that IMM is only returned if `IBV_WC_WITH_IMM` is set in `wc_flags`. If this is not - /// the case, no immediate value was provided, and `imm_data` should be interpreted - /// differently. See `man ibv_poll_cq` for details. + /// Returns immediate data if present. pub fn imm_data(&self) -> Option { if self.is_valid() && ((self.wc_flags & ibv_wc_flags::IBV_WC_WITH_IMM).0 != 0) { Some(self.imm_data) @@ -238,7 +143,138 @@ mod inner { pub use inner::*; -// RDMA error string function and CUDA utility functions +// ============================================================================= +// ROCm/HIP Compatibility Aliases +// ============================================================================= +// These allow monarch_rdma to use CUDA names transparently on ROCm builds. + +// --- Basic Type Aliases --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipDeviceptr_t as CUdeviceptr; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipDevice_t as CUdevice; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipError_t as CUresult; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipCtx_t as CUcontext; + +// --- Memory Allocation Types --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemGenericAllocationHandle_t as CUmemGenericAllocationHandle; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemAllocationProp as CUmemAllocationProp; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemAccessDesc as CUmemAccessDesc; + +// --- Status/Success Constants --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::HIP_SUCCESS as CUDA_SUCCESS; + +// --- Pointer Attribute Constants --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE as CU_POINTER_ATTRIBUTE_MEMORY_TYPE; + +// --- Memory Allocation Type Constants --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemAllocationTypePinned as CU_MEM_ALLOCATION_TYPE_PINNED; + +// --- Memory Location Type Constants --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemLocationTypeDevice as CU_MEM_LOCATION_TYPE_DEVICE; + +// --- Memory Handle Type Constants --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemHandleTypePosixFileDescriptor as CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + +// --- Memory Allocation Granularity Constants --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemAllocationGranularityMinimum as CU_MEM_ALLOC_GRANULARITY_MINIMUM; + +// --- Memory Access Flags Constants --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::hipMemAccessFlagsProtReadWrite as CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + +// --- Dmabuf Handle Type Constants --- +#[cfg(rocm_6_x)] +pub const CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD: i32 = 0; + +#[cfg(rocm_7_plus)] +pub use inner::hipMemRangeHandleTypeDmaBufFd as CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD; + +// --- Driver Init/Device Functions --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipInit as rdmaxcel_cuInit; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipDeviceGet as rdmaxcel_cuDeviceGet; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipDeviceGetCount as rdmaxcel_cuDeviceGetCount; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipPointerGetAttribute as rdmaxcel_cuPointerGetAttribute; + +// --- Context Functions --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipCtxCreate as rdmaxcel_cuCtxCreate_v2; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipCtxSetCurrent as rdmaxcel_cuCtxSetCurrent; + +// --- Error Handling Functions --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipGetErrorString as rdmaxcel_cuGetErrorString; + +// --- Memory Management Functions --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemGetAllocationGranularity as rdmaxcel_cuMemGetAllocationGranularity; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemCreate as rdmaxcel_cuMemCreate; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemAddressReserve as rdmaxcel_cuMemAddressReserve; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemMap as rdmaxcel_cuMemMap; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemSetAccess as rdmaxcel_cuMemSetAccess; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemUnmap as rdmaxcel_cuMemUnmap; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemAddressFree as rdmaxcel_cuMemAddressFree; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemRelease as rdmaxcel_cuMemRelease; + +// --- Memory Copy/Set Functions --- +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemcpyHtoD as rdmaxcel_cuMemcpyHtoD_v2; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemcpyDtoH as rdmaxcel_cuMemcpyDtoH_v2; + +#[cfg(any(rocm_6_x, rocm_7_plus))] +pub use inner::rdmaxcel_hipMemsetD8 as rdmaxcel_cuMemsetD8_v2; + +// --- Dmabuf Function --- +// ROCm 7+: direct alias to HIP function +#[cfg(rocm_7_plus)] +pub use inner::rdmaxcel_hipMemGetHandleForAddressRange as rdmaxcel_cuMemGetHandleForAddressRange; + +// ROCm 6.x: uses the CUDA-compatible wrapper we added in build.rs that internally calls HSA +#[cfg(rocm_6_x)] +pub use inner::rdmaxcel_cuMemGetHandleForAddressRange; + +// RDMA error string function and utility functions unsafe extern "C" { pub fn rdmaxcel_error_string(error_code: std::os::raw::c_int) -> *const std::os::raw::c_char; pub fn get_cuda_pci_address_from_ptr( From ac75664f4d127560b5752953ab1677d0e64fafb5 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Thu, 11 Dec 2025 16:16:02 +0000 Subject: [PATCH 03/12] added rdmaxcel_cpp so does not overwrite rdmaxcel library --- rdmaxcel-sys/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index daacf49fd..58e7bdc05 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -533,7 +533,7 @@ fn main() { } for include_dir in &libtorch_include_dirs { cpp_build.include(include_dir); } if let Some(include_dir) = &python_config.include_dir { cpp_build.include(include_dir); } - cpp_build.compile("rdmaxcel"); + cpp_build.compile("rdmaxcel_cpp"); } // Compile CUDA/HIP files From d5f48e4294af0419a17cf35ed8469aba986be1e7 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Thu, 11 Dec 2025 17:06:36 +0000 Subject: [PATCH 04/12] nccl-sys now is hipified --- nccl-sys/build.rs | 137 ++++++++++++++++++++++++++++++++++++++++---- nccl-sys/src/lib.rs | 8 +++ 2 files changed, 135 insertions(+), 10 deletions(-) diff --git a/nccl-sys/build.rs b/nccl-sys/build.rs index 9b4f7a97c..58bdb8f4a 100644 --- a/nccl-sys/build.rs +++ b/nccl-sys/build.rs @@ -6,22 +6,125 @@ * LICENSE file in the root directory of this source tree. */ +use std::env; +use std::fs; +use std::path::Path; use std::path::PathBuf; +use std::process::Command; #[cfg(target_os = "macos")] fn main() {} +/// Hipify the nccl.h header for ROCm compatibility +fn hipify_sources( + python_interpreter: &Path, + src_dir: &Path, + hip_src_dir: &Path, +) -> Result<(), Box> { + println!( + "cargo:warning=nccl-sys: Copying sources from {} to {} for hipify...", + src_dir.display(), + hip_src_dir.display() + ); + fs::create_dir_all(hip_src_dir)?; + + // Copy header file + let src_file = src_dir.join("nccl.h"); + let dest_file = hip_src_dir.join("nccl.h"); + if src_file.exists() { + fs::copy(&src_file, &dest_file)?; + println!("cargo:rerun-if-changed={}", src_file.display()); + } + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); + let project_root = manifest_dir + .parent() + .ok_or("Failed to find project root")?; + let hipify_script = project_root + .join("deps") + .join("hipify_torch") + .join("hipify_cli.py"); + + println!("cargo:warning=nccl-sys: Running hipify_torch..."); + let hipify_output = Command::new(python_interpreter) + .arg(&hipify_script) + .arg("--project-directory") + .arg(hip_src_dir) + .arg("--v2") + .arg("--output-directory") + .arg(hip_src_dir) + .output()?; + + if !hipify_output.status.success() { + return Err(format!( + "hipify_cli.py failed: {}", + String::from_utf8_lossy(&hipify_output.stderr) + ) + .into()); + } + + println!("cargo:warning=nccl-sys: hipify complete"); + Ok(()) +} + #[cfg(not(target_os = "macos"))] fn main() { + // Auto-detect ROCm vs CUDA using build_utils + let (is_rocm, compute_home, _rocm_version) = + if let Ok(rocm_home) = build_utils::validate_rocm_installation() { + let version = build_utils::get_rocm_version(&rocm_home).unwrap_or((6, 0)); + println!( + "cargo:warning=nccl-sys: Using RCCL from ROCm {}.{} at {}", + version.0, version.1, rocm_home + ); + println!("cargo:rustc-cfg=rocm"); + if version.0 >= 7 { + println!("cargo:rustc-cfg=rocm_7_plus"); + } else { + println!("cargo:rustc-cfg=rocm_6_x"); + } + (true, rocm_home, version) + } else if let Ok(cuda_home) = build_utils::validate_cuda_installation() { + println!( + "cargo:warning=nccl-sys: Using NCCL from CUDA at {}", + cuda_home + ); + (false, cuda_home, (0, 0)) + } else { + eprintln!("Error: Neither CUDA nor ROCm installation found!"); + std::process::exit(1); + }; + + // Emit cfg check declarations + println!("cargo:rustc-check-cfg=cfg(rocm)"); + println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let src_dir = manifest_dir.join("src"); + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + let python_interpreter = build_utils::find_python_interpreter(); + let compute_include_path = format!("{}/include", compute_home); + + // Determine which header to use + let header_path = if is_rocm { + // Hipify the sources for ROCm + let hip_src_dir = out_path.join("hipified_src"); + hipify_sources(&python_interpreter, &src_dir, &hip_src_dir) + .expect("Failed to hipify nccl-sys sources"); + + // The hipified header should now include + hip_src_dir.join("nccl_hip.h") + } else { + src_dir.join("nccl.h") + }; + let mut builder = bindgen::Builder::default() - .header("src/nccl.h") + .header(header_path.to_string_lossy()) .clang_arg("-x") .clang_arg("c++") .clang_arg("-std=c++14") - .clang_arg(format!( - "-I{}/include", - build_utils::find_cuda_home().unwrap() - )) + .clang_arg(format!("-I{}", compute_include_path)) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) // Communicator creation and management .allowlist_function("ncclGetLastError") @@ -60,9 +163,11 @@ fn main() { // User-defined reduction operators .allowlist_function("ncclRedOpCreatePreMulSum") .allowlist_function("ncclRedOpDestroy") - // Random nccl stuff we want + // CUDA/HIP stream and device functions .allowlist_function("cudaStream.*") + .allowlist_function("hipStream.*") .allowlist_function("cudaSetDevice") + .allowlist_function("hipSetDevice") .allowlist_type("ncclComm_t") .allowlist_type("ncclResult_t") .allowlist_type("ncclDataType_t") @@ -80,6 +185,13 @@ fn main() { is_global: false, }); + // Add platform-specific defines for bindgen + if is_rocm { + builder = builder + .clang_arg("-D__HIP_PLATFORM_AMD__=1") + .clang_arg("-DUSE_ROCM=1"); + } + // Include headers and libs from the active environment let python_config = match build_utils::python_env_dirs() { Ok(config) => config, @@ -97,20 +209,25 @@ fn main() { } if let Some(lib_dir) = &python_config.lib_dir { println!("cargo::rustc-link-search=native={}", lib_dir); - // Set cargo metadata to inform dependent binaries about how to set their - // RPATH (see controller/build.rs for an example). println!("cargo::metadata=LIB_PATH={}", lib_dir); } // Write the bindings to the $OUT_DIR/bindings.rs file. - let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()); builder .generate() .expect("Unable to generate bindings") .write_to_file(out_path.join("bindings.rs")) .expect("Couldn't write bindings!"); - println!("cargo::rustc-link-lib=nccl"); + // Link appropriate library + if is_rocm { + // RCCL is ROCm's NCCL-compatible library + println!("cargo::rustc-link-lib=rccl"); + println!("cargo::rustc-link-search=native={}/lib", compute_home); + } else { + println!("cargo::rustc-link-lib=nccl"); + } + println!("cargo::rustc-cfg=cargo"); println!("cargo::rustc-check-cfg=cfg(cargo)"); } diff --git a/nccl-sys/src/lib.rs b/nccl-sys/src/lib.rs index 76dacd19f..d85eb8b6d 100644 --- a/nccl-sys/src/lib.rs +++ b/nccl-sys/src/lib.rs @@ -10,11 +10,19 @@ use cxx::ExternType; use cxx::type_id; /// SAFETY: bindings +#[cfg(not(rocm))] unsafe impl ExternType for CUstream_st { type Id = type_id!("CUstream_st"); type Kind = cxx::kind::Opaque; } +/// SAFETY: bindings +#[cfg(rocm)] +unsafe impl ExternType for ihipStream_t { + type Id = type_id!("ihipStream_t"); + type Kind = cxx::kind::Opaque; +} + /// SAFETY: bindings /// Trivial because this is POD struct unsafe impl ExternType for ncclConfig_t { From 84bc60736e80744b27d585d3f64c95db3b660f05 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Thu, 11 Dec 2025 18:05:19 +0000 Subject: [PATCH 05/12] torch-sys-cuda now works with hipify_torch --- nccl-sys/src/lib.rs | 15 ++++ torch-sys-cuda/build.rs | 170 +++++++++++++++++++++++++++++++++++----- 2 files changed, 167 insertions(+), 18 deletions(-) diff --git a/nccl-sys/src/lib.rs b/nccl-sys/src/lib.rs index d85eb8b6d..e2fc9f17d 100644 --- a/nccl-sys/src/lib.rs +++ b/nccl-sys/src/lib.rs @@ -47,6 +47,7 @@ mod inner { use serde::Serialize; use serde::Serializer; use serde::ser::SerializeSeq; + #[cfg(cargo)] include!(concat!(env!("OUT_DIR"), "/bindings.rs")); @@ -91,6 +92,20 @@ mod inner { pub use inner::*; +// ============================================================================= +// ROCm/HIP Compatibility Aliases +// ============================================================================= +// These allow consumers (like torch-sys-cuda) to use CUDA names transparently on ROCm. + +#[cfg(rocm)] +pub use inner::hipError_t as cudaError_t; + +#[cfg(rocm)] +pub use inner::hipStream_t as cudaStream_t; + +#[cfg(rocm)] +pub use inner::hipSetDevice as cudaSetDevice; + #[cfg(test)] mod tests { use std::mem::MaybeUninit; diff --git a/torch-sys-cuda/build.rs b/torch-sys-cuda/build.rs index bf6928a74..5d7dfef75 100644 --- a/torch-sys-cuda/build.rs +++ b/torch-sys-cuda/build.rs @@ -6,24 +6,130 @@ * LICENSE file in the root directory of this source tree. */ -//! This build script locates CUDA libraries and headers for torch-sys-cuda, +//! This build script locates CUDA/ROCm libraries and headers for torch-sys-cuda, //! which provides CUDA-specific PyTorch functionality. It depends on the base //! torch-sys crate for core PyTorch integration. #![feature(exit_status_error)] +use std::env; +use std::fs; +use std::path::Path; use std::path::PathBuf; +use std::process::Command; -use build_utils::find_cuda_home; use cxx_build::CFG; #[cfg(target_os = "macos")] fn main() {} +/// Hipify the bridge sources for ROCm compatibility +fn hipify_sources( + python_interpreter: &Path, + src_dir: &Path, + hip_src_dir: &Path, +) -> Result> { + println!( + "cargo:warning=torch-sys-cuda: Copying sources from {} to {} for hipify...", + src_dir.display(), + hip_src_dir.display() + ); + fs::create_dir_all(hip_src_dir)?; + + // Copy bridge.h, bridge.cpp, and bridge.rs + for filename in &["bridge.h", "bridge.cpp", "bridge.rs"] { + let src_file = src_dir.join(filename); + let dest_file = hip_src_dir.join(filename); + if src_file.exists() { + fs::copy(&src_file, &dest_file)?; + println!("cargo:rerun-if-changed={}", src_file.display()); + } + } + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); + let project_root = manifest_dir + .parent() + .ok_or("Failed to find project root")?; + let hipify_script = project_root + .join("deps") + .join("hipify_torch") + .join("hipify_cli.py"); + + println!("cargo:warning=torch-sys-cuda: Running hipify_torch..."); + let hipify_output = Command::new(python_interpreter) + .arg(&hipify_script) + .arg("--project-directory") + .arg(hip_src_dir) + .arg("--v2") + .arg("--output-directory") + .arg(hip_src_dir) + .output()?; + + if !hipify_output.status.success() { + return Err(format!( + "hipify_cli.py failed: {}", + String::from_utf8_lossy(&hipify_output.stderr) + ) + .into()); + } + + // Modify bridge.cpp to include bridge_hip.h instead of the original header + let bridge_cpp_path = hip_src_dir.join("bridge.cpp"); + let bridge_cpp_content = fs::read_to_string(&bridge_cpp_path)?; + let modified_cpp = bridge_cpp_content.replace( + "#include \"monarch/torch-sys-cuda/src/bridge.h\"", + "#include \"bridge_hip.h\"", + ); + fs::write(&bridge_cpp_path, modified_cpp)?; + + // Modify bridge.rs to include bridge_hip.h instead of the original header + let bridge_rs_path = hip_src_dir.join("bridge.rs"); + let bridge_rs_content = fs::read_to_string(&bridge_rs_path)?; + let modified_rs = bridge_rs_content.replace( + "include!(\"monarch/torch-sys-cuda/src/bridge.h\")", + "include!(\"bridge_hip.h\")", + ); + fs::write(&bridge_rs_path, modified_rs)?; + + println!("cargo:warning=torch-sys-cuda: hipify complete"); + + // Return the path to the hipified bridge.rs + Ok(bridge_rs_path) +} + #[cfg(not(target_os = "macos"))] fn main() { + // Auto-detect ROCm vs CUDA using build_utils + let (is_rocm, compute_home) = + if let Ok(rocm_home) = build_utils::validate_rocm_installation() { + let version = build_utils::get_rocm_version(&rocm_home).unwrap_or((6, 0)); + println!( + "cargo:warning=torch-sys-cuda: Using ROCm {}.{} at {}", + version.0, version.1, rocm_home + ); + println!("cargo:rustc-cfg=rocm"); + if version.0 >= 7 { + println!("cargo:rustc-cfg=rocm_7_plus"); + } else { + println!("cargo:rustc-cfg=rocm_6_x"); + } + (true, rocm_home) + } else if let Ok(cuda_home) = build_utils::validate_cuda_installation() { + println!( + "cargo:warning=torch-sys-cuda: Using CUDA at {}", + cuda_home + ); + (false, cuda_home) + } else { + panic!("Neither CUDA nor ROCm installation found!"); + }; + + // Emit cfg check declarations + println!("cargo:rustc-check-cfg=cfg(rocm)"); + println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); + // Use PyO3's Python discovery to find the correct Python library paths - // This is more robust than hardcoding platform-specific paths let mut python_lib_dir: Option = None; let python_config = pyo3_build_config::get(); @@ -34,30 +140,58 @@ fn main() { } // On some platforms, we may need to explicitly link against Python - // PyO3 handles the complexity of determining when this is needed if let Some(lib_name) = &python_config.lib_name { println!("cargo::rustc-link-lib={}", lib_name); } - let cuda_home = PathBuf::from(find_cuda_home().expect("CUDA installation not found")); + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let src_dir = manifest_dir.join("src"); + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - // Prefix includes with `monarch` to maintain consistency with fbcode - // folder structure + // Determine source files to compile + let (bridge_rs_path, bridge_cpp_path, include_dir) = if is_rocm { + let python_interpreter = build_utils::find_python_interpreter(); + let hip_src_dir = out_path.join("hipified_src"); + let hipified_bridge_rs = hipify_sources(&python_interpreter, &src_dir, &hip_src_dir) + .expect("Failed to hipify torch-sys-cuda sources"); + + (hipified_bridge_rs, hip_src_dir.join("bridge.cpp"), hip_src_dir) + } else { + (src_dir.join("bridge.rs"), src_dir.join("bridge.cpp"), src_dir.clone()) + }; + + // Prefix includes with `monarch` to maintain consistency with fbcode folder structure CFG.include_prefix = "monarch/torch-sys-cuda"; - let _builder = cxx_build::bridge("src/bridge.rs") - .file("src/bridge.cpp") + + let mut builder = cxx_build::bridge(&bridge_rs_path); + builder + .file(&bridge_cpp_path) .flag("-std=c++14") - .include(format!("{}/include", cuda_home.display())) + .include(format!("{}/include", compute_home)) + .include(&include_dir) // Suppress warnings, otherwise we get massive spew from libtorch - .flag_if_supported("-w") - .compile("torch-sys-cuda"); + .flag_if_supported("-w"); - // Configure CUDA-specific linking - println!("cargo::rustc-link-lib=cudart"); - println!( - "cargo::rustc-link-search=native={}/lib64", - cuda_home.display() - ); + // Add platform-specific defines + if is_rocm { + builder + .define("__HIP_PLATFORM_AMD__", "1") + .define("USE_ROCM", "1"); + } + + builder.compile("torch-sys-cuda"); + + // Configure platform-specific linking + if is_rocm { + // ROCm uses amdhip64 and rccl + println!("cargo::rustc-link-lib=amdhip64"); + println!("cargo::rustc-link-lib=rccl"); + println!("cargo::rustc-link-search=native={}/lib", compute_home); + } else { + // CUDA uses cudart + println!("cargo::rustc-link-lib=cudart"); + println!("cargo::rustc-link-search=native={}/lib64", compute_home); + } // Add Python library directory to rpath for runtime linking if let Some(python_lib_dir) = &python_lib_dir { From 0af3ccd04017aca2831e4f054fd3cfb420bf6103 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Thu, 11 Dec 2025 19:08:46 +0000 Subject: [PATCH 06/12] refactor code --- build_utils/src/lib.rs | 85 ++++++++++++ cuda-sys/build.rs | 285 ---------------------------------------- nccl-sys/build.rs | 76 ++--------- rdmaxcel-sys/build.rs | 56 ++++---- rdmaxcel-sys/src/lib.rs | 1 + torch-sys-cuda/build.rs | 69 ++++------ 6 files changed, 150 insertions(+), 422 deletions(-) delete mode 100644 cuda-sys/build.rs diff --git a/build_utils/src/lib.rs b/build_utils/src/lib.rs index 22ee6c58a..801ab8498 100644 --- a/build_utils/src/lib.rs +++ b/build_utils/src/lib.rs @@ -513,6 +513,91 @@ pub fn print_rocm_lib_error_help() { eprintln!("Example: export ROCM_LIB_DIR=/opt/rocm/lib"); } +/// Run hipify_torch to convert CUDA sources to HIP +/// +/// This function: +/// 1. Creates output_dir if needed +/// 2. Copies all source_files to output_dir +/// 3. Finds deps/hipify_torch/hipify_cli.py relative to project_root +/// 4. Runs hipify_torch with --v2 flag +/// +/// After this function returns, hipified files will be in output_dir with +/// "_hip" suffix (e.g., "bridge.h" becomes "bridge_hip.h"). +/// +/// # Arguments +/// * `project_root` - Path to the monarch project root (contains deps/hipify_torch) +/// * `source_files` - Files to copy and hipify +/// * `output_dir` - Directory where hipified files will be written +/// +/// # Returns +/// * `Ok(())` on success +/// * `Err(BuildError)` if hipify fails +pub fn run_hipify_torch( + project_root: &Path, + source_files: &[PathBuf], + output_dir: &Path, +) -> Result<(), BuildError> { + // Create output directory if needed + fs::create_dir_all(output_dir).map_err(|e| { + BuildError::PathNotFound(format!("Failed to create output directory: {}", e)) + })?; + + // Copy source files to output directory + for source_file in source_files { + let filename = source_file.file_name().ok_or_else(|| { + BuildError::PathNotFound(format!("Invalid source file path: {:?}", source_file)) + })?; + let dest = output_dir.join(filename); + fs::copy(source_file, &dest).map_err(|e| { + BuildError::CommandFailed(format!( + "Failed to copy {:?} to {:?}: {}", + source_file, dest, e + )) + })?; + println!("cargo:rerun-if-changed={}", source_file.display()); + } + + // Find hipify script + let hipify_script = project_root.join("deps/hipify_torch/hipify_cli.py"); + if !hipify_script.exists() { + return Err(BuildError::PathNotFound(format!( + "hipify_cli.py not found at {:?}", + hipify_script + ))); + } + + // Get Python interpreter (defined in this module) + let python = find_python_interpreter(); + + // Run hipify_torch + let output = Command::new(&python) + .arg(&hipify_script) + .arg("--project-directory") + .arg(output_dir) + .arg("--v2") + .arg("--output-directory") + .arg(output_dir) + .output() + .map_err(|e| BuildError::CommandFailed(format!("Failed to run hipify_torch: {}", e)))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + return Err(BuildError::CommandFailed(format!( + "hipify_torch failed:\nstdout: {}\nstderr: {}", + stdout, stderr + ))); + } + + println!( + "cargo:warning=Successfully hipified {} files to {:?}", + source_files.len(), + output_dir + ); + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/cuda-sys/build.rs b/cuda-sys/build.rs deleted file mode 100644 index 8a596d838..000000000 --- a/cuda-sys/build.rs +++ /dev/null @@ -1,285 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use std::env; -use std::fs; -use std::path::Path; -use std::path::PathBuf; -use std::process::Command; - -// --- HIPify Helper Functions (cuda-sys specific) --- - -/// Applies the required 'CUstream_st' typedef fix to the hipified header. -fn patch_hipified_header(hipified_file_path: &Path) -> Result<(), Box> { - println!("cargo:warning=Patching hipified header for CUstream_st typedef..."); - - let hip_typedef = "\n// HIP/ROCm Fix: Manually define CUstream_st for cxx bindings\ntypedef struct ihipStream_t CUstream_st;\n"; - - let original_content = fs::read_to_string(hipified_file_path)?; - let lines: Vec<&str> = original_content.lines().collect(); - let mut insert_index = 0; - - for (i, line) in lines.iter().enumerate() { - if !line.trim().starts_with("#include") - && !line.trim().is_empty() - && !line.trim().starts_with("//") - { - insert_index = i; - break; - } - if i == lines.len() - 1 { - insert_index = lines.len(); - } - } - - let mut new_content = String::new(); - for (i, line) in lines.iter().enumerate() { - if i == insert_index { - new_content.push_str(hip_typedef); - } - new_content.push_str(line); - new_content.push('\n'); - } - - fs::write( - hipified_file_path, - new_content.trim_end_matches('\n').as_bytes(), - )?; - - println!("cargo:warning=Successfully injected CUstream_st typedef."); - Ok(()) -} - -/// Runs `hipify_torch` on the source file. -/// Returns the path to the newly hipified header file. -fn hipify_source_header( - python_interpreter: &Path, - src_dir: &Path, - hip_src_dir: &Path, - file_name: &str, -) -> Result> { - println!( - "cargo:warning=Copying source header {} to {} for in-place hipify...", - file_name, - hip_src_dir.display() - ); - fs::create_dir_all(hip_src_dir)?; - - let src_file = src_dir.join(file_name); - let dest_file = hip_src_dir.join(file_name); - - if src_file.exists() { - fs::copy(&src_file, &dest_file)?; - println!("cargo:rerun-if-changed={}", src_file.display()); - } else { - return Err(format!("Source file {} not found", src_file.display()).into()); - } - - println!("cargo:warning=Running hipify_torch in-place on copied sources with --v2..."); - - let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); - let project_root = manifest_dir - .parent() - .ok_or("Failed to find project root: manifest parent not found")?; - - let hipify_script = project_root - .join("deps") - .join("hipify_torch") - .join("hipify_cli.py"); - - if !hipify_script.exists() { - return Err(format!("hipify_cli.py not found at {}", hipify_script.display()).into()); - } - println!("cargo:rerun-if-changed={}", hipify_script.display()); - - let hipify_output = Command::new(python_interpreter) - .arg(&hipify_script) - .arg("--project-directory") - .arg(hip_src_dir) - .arg("--v2") - .output()?; - - if !hipify_output.status.success() { - return Err(format!( - "hipify_cli.py failed: {}", - String::from_utf8_lossy(&hipify_output.stderr) - ) - .into()); - } - - println!("cargo:warning=Successfully hipified {} source", file_name); - - // The hipified output file name is wrapper_hip.h - let hip_file = hip_src_dir.join("wrapper_hip.h"); - - if hip_file.exists() { - patch_hipified_header(&hip_file)?; - Ok(hip_file) - } else { - let fallback_file = hip_src_dir.join(file_name); - if fallback_file.exists() { - patch_hipified_header(&fallback_file)?; - Ok(fallback_file) - } else { - Err(format!( - "Hipified output file not found. Expected: {}", - hip_file.display() - ) - .into()) - } - } -} - -// --- Main Build Logic --- - -#[cfg(target_os = "macos")] -fn main() {} - -#[cfg(not(target_os = "macos"))] -fn main() { - const CUDA_HEADER_NAME: &str = "wrapper.h"; - - let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); - let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); - - // Check if we are building for ROCm (HIP) - check ROCm first - let is_rocm = build_utils::find_rocm_home().is_some(); - - println!("cargo:rerun-if-env-changed=USE_ROCM"); - - let header_path; - let compute_lib_names; - let compute_config; - - if is_rocm { - println!("cargo:warning=Using HIP from ROCm installation"); - compute_lib_names = vec!["amdhip64"]; - - // HIPify the CUDA wrapper header - let hip_src_dir = out_dir.join("hipified_src"); - let python_interpreter = build_utils::find_python_interpreter(); - - header_path = hipify_source_header( - &python_interpreter, - &manifest_dir.join("src"), - &hip_src_dir, - CUDA_HEADER_NAME, - ) - .expect("Failed to hipify wrapper.h"); - - // Discover ROCm configuration - match build_utils::discover_rocm_config() { - Ok(config) => { - compute_config = build_utils::CudaConfig { - cuda_home: config.rocm_home, - include_dirs: config.include_dirs, - lib_dirs: config.lib_dirs, - } - } - Err(_) => { - build_utils::print_rocm_error_help(); - std::process::exit(1); - } - } - } else { - println!("cargo:warning=Using CUDA"); - compute_lib_names = vec!["cuda", "cudart"]; - header_path = manifest_dir.join("src").join(CUDA_HEADER_NAME); - - match build_utils::discover_cuda_config() { - Ok(config) => compute_config = config, - Err(_) => { - build_utils::print_cuda_error_help(); - std::process::exit(1); - } - } - } - - // Configure bindgen - let mut builder = bindgen::Builder::default() - .header(header_path.to_str().expect("Invalid header path")) - .clang_arg("-x") - .clang_arg("c++") - .clang_arg("-std=gnu++20") - .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - .allowlist_function("cuda.*") - .allowlist_function("CUDA.*") - .allowlist_type("cuda.*") - .allowlist_type("CUDA.*") - .allowlist_type("CUstream_st") - .allowlist_function("hip.*") - .allowlist_type("hip.*") - .default_enum_style(bindgen::EnumVariation::NewType { - is_bitfield: false, - is_global: false, - }); - - for include_dir in &compute_config.include_dirs { - builder = builder.clang_arg(format!("-I{}", include_dir.display())); - } - - if is_rocm { - builder = builder - .clang_arg("-D__HIP_PLATFORM_AMD__=1") - .clang_arg("-DUSE_ROCM=1"); - } - - // Python environment - let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { - Ok(config) => config, - Err(_) => { - eprintln!("Warning: Failed to get Python environment directories"); - build_utils::PythonConfig { - include_dir: None, - lib_dir: None, - } - } - }; - - if let Some(include_dir) = &python_config.include_dir { - builder = builder.clang_arg(format!("-I{}", include_dir)); - } - if let Some(lib_dir) = &python_config.lib_dir { - println!("cargo:rustc-link-search=native={}", lib_dir); - println!("cargo:metadata=LIB_PATH={}", lib_dir); - } - - // Link compute libraries - let compute_lib_dir = if is_rocm { - match build_utils::get_rocm_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_rocm_lib_error_help(); - std::process::exit(1); - } - } - } else { - match build_utils::get_cuda_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_cuda_lib_error_help(); - std::process::exit(1); - } - } - }; - println!("cargo:rustc-link-search=native={}", compute_lib_dir); - for lib_name in compute_lib_names { - println!("cargo:rustc-link-lib={}", lib_name); - } - - // Generate bindings - let bindings = builder.generate().expect("Unable to generate bindings"); - - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings"); - - println!("cargo:rustc-cfg=cargo"); - println!("cargo:rustc-check-cfg=cfg(cargo)"); -} diff --git a/nccl-sys/build.rs b/nccl-sys/build.rs index 58bdb8f4a..3a270609c 100644 --- a/nccl-sys/build.rs +++ b/nccl-sys/build.rs @@ -7,68 +7,19 @@ */ use std::env; -use std::fs; -use std::path::Path; use std::path::PathBuf; -use std::process::Command; #[cfg(target_os = "macos")] fn main() {} -/// Hipify the nccl.h header for ROCm compatibility -fn hipify_sources( - python_interpreter: &Path, - src_dir: &Path, - hip_src_dir: &Path, -) -> Result<(), Box> { - println!( - "cargo:warning=nccl-sys: Copying sources from {} to {} for hipify...", - src_dir.display(), - hip_src_dir.display() - ); - fs::create_dir_all(hip_src_dir)?; - - // Copy header file - let src_file = src_dir.join("nccl.h"); - let dest_file = hip_src_dir.join("nccl.h"); - if src_file.exists() { - fs::copy(&src_file, &dest_file)?; - println!("cargo:rerun-if-changed={}", src_file.display()); - } - - let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); - let project_root = manifest_dir - .parent() - .ok_or("Failed to find project root")?; - let hipify_script = project_root - .join("deps") - .join("hipify_torch") - .join("hipify_cli.py"); - - println!("cargo:warning=nccl-sys: Running hipify_torch..."); - let hipify_output = Command::new(python_interpreter) - .arg(&hipify_script) - .arg("--project-directory") - .arg(hip_src_dir) - .arg("--v2") - .arg("--output-directory") - .arg(hip_src_dir) - .output()?; - - if !hipify_output.status.success() { - return Err(format!( - "hipify_cli.py failed: {}", - String::from_utf8_lossy(&hipify_output.stderr) - ) - .into()); - } - - println!("cargo:warning=nccl-sys: hipify complete"); - Ok(()) -} - #[cfg(not(target_os = "macos"))] fn main() { + // Declare custom cfg options to avoid warnings + println!("cargo::rustc-check-cfg=cfg(cargo)"); + println!("cargo::rustc-check-cfg=cfg(rocm)"); + println!("cargo::rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo::rustc-check-cfg=cfg(rocm_7_plus)"); + // Auto-detect ROCm vs CUDA using build_utils let (is_rocm, compute_home, _rocm_version) = if let Ok(rocm_home) = build_utils::validate_rocm_installation() { @@ -95,22 +46,20 @@ fn main() { std::process::exit(1); }; - // Emit cfg check declarations - println!("cargo:rustc-check-cfg=cfg(rocm)"); - println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); - println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); - let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); let src_dir = manifest_dir.join("src"); let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - let python_interpreter = build_utils::find_python_interpreter(); let compute_include_path = format!("{}/include", compute_home); // Determine which header to use let header_path = if is_rocm { - // Hipify the sources for ROCm + // Hipify the sources for ROCm using centralized build_utils function let hip_src_dir = out_path.join("hipified_src"); - hipify_sources(&python_interpreter, &src_dir, &hip_src_dir) + let project_root = manifest_dir.parent().expect("Failed to find project root"); + + let source_files = vec![src_dir.join("nccl.h")]; + + build_utils::run_hipify_torch(project_root, &source_files, &hip_src_dir) .expect("Failed to hipify nccl-sys sources"); // The hipified header should now include @@ -229,5 +178,4 @@ fn main() { } println!("cargo::rustc-cfg=cargo"); - println!("cargo::rustc-check-cfg=cfg(cargo)"); } diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index 58e7bdc05..bf8c2c69b 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -262,7 +262,7 @@ fn patch_hipified_files_rocm6(hip_src_dir: &Path) -> Result<(), Box Result<(), Box Result<(), Box> { - println!("cargo:warning=Copying sources from {} to {}...", src_dir.display(), hip_src_dir.display()); - fs::create_dir_all(hip_src_dir)?; + println!("cargo:warning=Hipifying sources from {} to {}...", src_dir.display(), hip_src_dir.display()); + // Collect source files to hipify let files_to_copy = [ "lib.rs", "rdmaxcel.h", "rdmaxcel.c", "rdmaxcel.cpp", "rdmaxcel.cu", "test_rdmaxcel.c", "driver_api.h", "driver_api.cpp", ]; - for file_name in files_to_copy { - let src_file = src_dir.join(file_name); - let dest_file = hip_src_dir.join(file_name); - if src_file.exists() { - fs::copy(&src_file, &dest_file)?; - println!("cargo:rerun-if-changed={}", src_file.display()); - } - } + let source_files: Vec = files_to_copy + .iter() + .map(|f| src_dir.join(f)) + .filter(|p| p.exists()) + .collect(); + + // Find project root let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); let project_root = manifest_dir.parent().ok_or("Failed to find project root")?; - let hipify_script = project_root.join("deps").join("hipify_torch").join("hipify_cli.py"); - - let hipify_output = Command::new(python_interpreter) - .arg(&hipify_script) - .arg("--project-directory").arg(hip_src_dir) - .arg("--v2") - .arg("--output-directory").arg(hip_src_dir) - .output()?; - if !hipify_output.status.success() { - return Err(format!("hipify_cli.py failed: {}", String::from_utf8_lossy(&hipify_output.stderr)).into()); - } + // Use centralized hipify function from build_utils + build_utils::run_hipify_torch(project_root, &source_files, hip_src_dir) + .map_err(|e| format!("hipify_torch failed: {}", e))?; + // Apply rdmaxcel-specific patches based on ROCm version let (major, _minor) = rocm_version; if major >= 7 { patch_hipified_files_rocm7(hip_src_dir)?; } else { patch_hipified_files_rocm6(hip_src_dir)?; } + Ok(()) } @@ -366,6 +360,11 @@ fn main() {} #[cfg(not(target_os = "macos"))] fn main() { + // Declare custom cfg options to avoid warnings + println!("cargo::rustc-check-cfg=cfg(cargo)"); + println!("cargo::rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo::rustc-check-cfg=cfg(rocm_7_plus)"); + println!("cargo:rustc-link-lib=ibverbs"); println!("cargo:rustc-link-lib=mlx5"); @@ -383,9 +382,6 @@ fn main() { std::process::exit(1); }; - println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); - println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); - let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); let src_dir = manifest_dir.join("src"); let python_interpreter = build_utils::find_python_interpreter(); @@ -433,7 +429,7 @@ fn main() { if is_rocm { let hip_src_dir = out_path.join("hipified_src"); - hipify_sources(&python_interpreter, &src_dir, &hip_src_dir, rocm_version).expect("Failed to hipify sources"); + hipify_sources(&src_dir, &hip_src_dir, rocm_version).expect("Failed to hipify sources"); validate_hipified_files(&hip_src_dir).expect("Hipified files validation failed"); code_dir = hip_src_dir.clone(); header_path = hip_src_dir.join("rdmaxcel_hip.h"); @@ -525,6 +521,10 @@ fn main() { if cpp_source_path.exists() { let mut cpp_build = cc::Build::new(); cpp_build.file(&cpp_source_path).include(&code_dir).flag("-fPIC").cpp(true).flag("-std=gnu++20").flag("-Wno-unused-parameter").define("PYTORCH_C10_DRIVER_API_SUPPORTED", "1"); + // Suppress deprecated API warnings for HIP context management APIs (deprecated in ROCm 6.x) + if is_rocm { + cpp_build.flag("-Wno-deprecated-declarations"); + } if driver_api_cpp_path.exists() { cpp_build.file(&driver_api_cpp_path); } cpp_build.include(&compute_include_path); if is_rocm { @@ -538,7 +538,7 @@ fn main() { // Compile CUDA/HIP files if cuda_source_path.exists() { - let (compiler_path, compiler_name) = if is_rocm { (format!("{}/bin/hipcc", compute_home), "hipcc") } else { (format!("{}/bin/nvcc", compute_home), "nvcc") }; + let compiler_path = if is_rocm { format!("{}/bin/hipcc", compute_home) } else { format!("{}/bin/nvcc", compute_home) }; let cuda_build_dir = format!("{}/target/cuda_build", manifest_dir.display()); std::fs::create_dir_all(&cuda_build_dir).expect("Failed to create CUDA build directory"); let cuda_obj_path = format!("{}/rdmaxcel_cuda.o", cuda_build_dir); diff --git a/rdmaxcel-sys/src/lib.rs b/rdmaxcel-sys/src/lib.rs index f774e30f7..34058558a 100644 --- a/rdmaxcel-sys/src/lib.rs +++ b/rdmaxcel-sys/src/lib.rs @@ -14,6 +14,7 @@ mod inner { #![allow(non_camel_case_types)] #![allow(non_snake_case)] #![allow(unused_attributes)] + #![allow(dead_code)] // bindgen generates many functions we may not use #[cfg(not(cargo))] use crate::ibv_wc_flags; #[cfg(not(cargo))] diff --git a/torch-sys-cuda/build.rs b/torch-sys-cuda/build.rs index 5d7dfef75..d6ae57bf5 100644 --- a/torch-sys-cuda/build.rs +++ b/torch-sys-cuda/build.rs @@ -16,63 +16,43 @@ use std::env; use std::fs; use std::path::Path; use std::path::PathBuf; -use std::process::Command; use cxx_build::CFG; #[cfg(target_os = "macos")] fn main() {} -/// Hipify the bridge sources for ROCm compatibility +/// Hipify the bridge sources for ROCm compatibility using build_utils, +/// then apply torch-sys-cuda specific patches. +/// Returns the path to the hipified bridge.rs file. fn hipify_sources( - python_interpreter: &Path, src_dir: &Path, hip_src_dir: &Path, ) -> Result> { println!( - "cargo:warning=torch-sys-cuda: Copying sources from {} to {} for hipify...", + "cargo:warning=torch-sys-cuda: Hipifying sources from {} to {}...", src_dir.display(), hip_src_dir.display() ); - fs::create_dir_all(hip_src_dir)?; - - // Copy bridge.h, bridge.cpp, and bridge.rs - for filename in &["bridge.h", "bridge.cpp", "bridge.rs"] { - let src_file = src_dir.join(filename); - let dest_file = hip_src_dir.join(filename); - if src_file.exists() { - fs::copy(&src_file, &dest_file)?; - println!("cargo:rerun-if-changed={}", src_file.display()); - } - } + // Find project root let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); let project_root = manifest_dir .parent() .ok_or("Failed to find project root")?; - let hipify_script = project_root - .join("deps") - .join("hipify_torch") - .join("hipify_cli.py"); - - println!("cargo:warning=torch-sys-cuda: Running hipify_torch..."); - let hipify_output = Command::new(python_interpreter) - .arg(&hipify_script) - .arg("--project-directory") - .arg(hip_src_dir) - .arg("--v2") - .arg("--output-directory") - .arg(hip_src_dir) - .output()?; - - if !hipify_output.status.success() { - return Err(format!( - "hipify_cli.py failed: {}", - String::from_utf8_lossy(&hipify_output.stderr) - ) - .into()); - } + // Collect source files to hipify + let source_files: Vec = ["bridge.h", "bridge.cpp", "bridge.rs"] + .iter() + .map(|f| src_dir.join(f)) + .filter(|p| p.exists()) + .collect(); + + // Use centralized hipify function from build_utils + build_utils::run_hipify_torch(project_root, &source_files, hip_src_dir) + .map_err(|e| format!("hipify_torch failed: {}", e))?; + + // Apply torch-sys-cuda specific patches: // Modify bridge.cpp to include bridge_hip.h instead of the original header let bridge_cpp_path = hip_src_dir.join("bridge.cpp"); let bridge_cpp_content = fs::read_to_string(&bridge_cpp_path)?; @@ -92,13 +72,18 @@ fn hipify_sources( fs::write(&bridge_rs_path, modified_rs)?; println!("cargo:warning=torch-sys-cuda: hipify complete"); - + // Return the path to the hipified bridge.rs Ok(bridge_rs_path) } #[cfg(not(target_os = "macos"))] fn main() { + // Declare custom cfg options to avoid warnings + println!("cargo::rustc-check-cfg=cfg(rocm)"); + println!("cargo::rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo::rustc-check-cfg=cfg(rocm_7_plus)"); + // Auto-detect ROCm vs CUDA using build_utils let (is_rocm, compute_home) = if let Ok(rocm_home) = build_utils::validate_rocm_installation() { @@ -124,11 +109,6 @@ fn main() { panic!("Neither CUDA nor ROCm installation found!"); }; - // Emit cfg check declarations - println!("cargo:rustc-check-cfg=cfg(rocm)"); - println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); - println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); - // Use PyO3's Python discovery to find the correct Python library paths let mut python_lib_dir: Option = None; let python_config = pyo3_build_config::get(); @@ -150,9 +130,8 @@ fn main() { // Determine source files to compile let (bridge_rs_path, bridge_cpp_path, include_dir) = if is_rocm { - let python_interpreter = build_utils::find_python_interpreter(); let hip_src_dir = out_path.join("hipified_src"); - let hipified_bridge_rs = hipify_sources(&python_interpreter, &src_dir, &hip_src_dir) + let hipified_bridge_rs = hipify_sources(&src_dir, &hip_src_dir) .expect("Failed to hipify torch-sys-cuda sources"); (hipified_bridge_rs, hip_src_dir.join("bridge.cpp"), hip_src_dir) From 4748059417a6082c2daf948c67d3b0c5f9f98e57 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Thu, 11 Dec 2025 21:46:56 +0000 Subject: [PATCH 07/12] now working with the cpp static libs --- build_utils/src/lib.rs | 22 +++++++++++++-------- monarch_cpp_static_libs/build.rs | 27 ++++++++++++++++---------- monarch_rdma/build.rs | 9 ++------- monarch_rdma/src/rdma_manager_actor.rs | 2 +- rdmaxcel-sys/build.rs | 19 +++++++++++++++--- 5 files changed, 50 insertions(+), 29 deletions(-) diff --git a/build_utils/src/lib.rs b/build_utils/src/lib.rs index b167ca758..0ecac3669 100644 --- a/build_utils/src/lib.rs +++ b/build_utils/src/lib.rs @@ -108,7 +108,7 @@ impl std::error::Error for BuildError {} /// Get environment variable with cargo rerun notification pub fn get_env_var_with_rerun(name: &str) -> Result { - println!("cargo::rerun-if-env-changed={}", name); + println!("cargo:rerun-if-env-changed={}", name); env::var(name) } @@ -677,6 +677,7 @@ pub struct CppStaticLibsConfig { pub rdma_include: String, pub rdma_lib_dir: String, pub rdma_util_dir: String, + pub rdma_ccan_dir: String, } impl CppStaticLibsConfig { @@ -691,6 +692,8 @@ impl CppStaticLibsConfig { .expect("DEP_MONARCH_CPP_STATIC_LIBS_RDMA_LIB_DIR not set - add monarch_cpp_static_libs as build-dependency"), rdma_util_dir: std::env::var("DEP_MONARCH_CPP_STATIC_LIBS_RDMA_UTIL_DIR") .expect("DEP_MONARCH_CPP_STATIC_LIBS_RDMA_UTIL_DIR not set - add monarch_cpp_static_libs as build-dependency"), + rdma_ccan_dir: std::env::var("DEP_MONARCH_CPP_STATIC_LIBS_RDMA_CCAN_DIR") + .expect("DEP_MONARCH_CPP_STATIC_LIBS_RDMA_CCAN_DIR not set - add monarch_cpp_static_libs as build-dependency"), } } @@ -700,19 +703,22 @@ impl CppStaticLibsConfig { /// - libmlx5.a /// - libibverbs.a /// - librdma_util.a + /// - libccan.a pub fn emit_link_directives(&self) { // Emit link search paths - println!("cargo::rustc-link-search=native={}", self.rdma_lib_dir); - println!("cargo::rustc-link-search=native={}", self.rdma_util_dir); + println!("cargo:rustc-link-search=native={}", self.rdma_lib_dir); + println!("cargo:rustc-link-search=native={}", self.rdma_util_dir); + println!("cargo:rustc-link-search=native={}", self.rdma_ccan_dir); // Use whole-archive for rdma-core static libraries - println!("cargo::rustc-link-arg=-Wl,--whole-archive"); - println!("cargo::rustc-link-lib=static=mlx5"); - println!("cargo::rustc-link-lib=static=ibverbs"); - println!("cargo::rustc-link-arg=-Wl,--no-whole-archive"); + println!("cargo:rustc-link-arg=-Wl,--whole-archive"); + println!("cargo:rustc-link-lib=static=mlx5"); + println!("cargo:rustc-link-lib=static=ibverbs"); + println!("cargo:rustc-link-arg=-Wl,--no-whole-archive"); // rdma_util helper library - println!("cargo::rustc-link-lib=static=rdma_util"); + println!("cargo:rustc-link-lib=static=rdma_util"); + println!("cargo:rustc-link-lib=static=ccan"); } } diff --git a/monarch_cpp_static_libs/build.rs b/monarch_cpp_static_libs/build.rs index fdce92f75..f3ce46d1e 100644 --- a/monarch_cpp_static_libs/build.rs +++ b/monarch_cpp_static_libs/build.rs @@ -10,7 +10,7 @@ //! //! This build script: //! 1. Obtains rdma-core source (from MONARCH_RDMA_CORE_SRC or by cloning) -//! 2. Builds rdma-core with static libraries (libibverbs.a, libmlx5.a) +//! 2. Builds rdma-core with static libraries (libibverbs.a, libmlx5.a, libccan.a) //! 3. Emits link directives for downstream crates use std::path::Path; @@ -144,8 +144,10 @@ fn copy_dir(src_dir: &Path, target_dir: &Path) { fn build_rdma_core(rdma_core_dir: &Path) -> PathBuf { let build_dir = rdma_core_dir.join("build"); - // Check if already built - if build_dir.join("lib/statics/libibverbs.a").exists() { + // Check if already built (Must include ccan check now) + if build_dir.join("lib/statics/libibverbs.a").exists() + && build_dir.join("ccan/libccan.a").exists() + { println!("cargo:warning=rdma-core already built"); return build_dir; } @@ -208,12 +210,12 @@ fn build_rdma_core(rdma_core_dir: &Path) -> PathBuf { panic!("Failed to configure rdma-core with cmake"); } - // Build only the targets we need: libibverbs.a, libmlx5.a, and librdma_util.a - // We don't need librdmacm which has build issues with long paths + // Build targets: ADDED ccan/libccan.a let targets = [ "lib/statics/libibverbs.a", "lib/statics/libmlx5.a", "util/librdma_util.a", + "ccan/libccan.a", ]; for target in &targets { @@ -246,6 +248,7 @@ fn build_rdma_core(rdma_core_dir: &Path) -> PathBuf { fn emit_link_directives(rdma_build_dir: &Path) { let rdma_static_dir = rdma_build_dir.join("lib/statics"); let rdma_util_dir = rdma_build_dir.join("util"); + let rdma_ccan_dir = rdma_build_dir.join("ccan"); // Emit search paths println!( @@ -253,6 +256,7 @@ fn emit_link_directives(rdma_build_dir: &Path) { rdma_static_dir.display() ); println!("cargo:rustc-link-search=native={}", rdma_util_dir.display()); + println!("cargo:rustc-link-search=native={}", rdma_ccan_dir.display()); // Static libraries - use whole-archive for rdma-core static libraries println!("cargo:rustc-link-arg=-Wl,--whole-archive"); @@ -260,17 +264,20 @@ fn emit_link_directives(rdma_build_dir: &Path) { println!("cargo:rustc-link-lib=static=ibverbs"); println!("cargo:rustc-link-arg=-Wl,--no-whole-archive"); - // rdma_util helper library + // Helper libraries println!("cargo:rustc-link-lib=static=rdma_util"); + println!("cargo:rustc-link-lib=static=ccan"); // Export metadata for dependent crates - // Use cargo:: (double colon) format for proper DEP__ env vars + // UPDATED: Using single colon 'cargo:key=value' is more compatible with build scripts + // that read metadata via DEP_PKG_KEY env vars. println!( - "cargo::metadata=RDMA_INCLUDE={}", + "cargo:RDMA_INCLUDE={}", rdma_build_dir.join("include").display() ); - println!("cargo::metadata=RDMA_LIB_DIR={}", rdma_static_dir.display()); - println!("cargo::metadata=RDMA_UTIL_DIR={}", rdma_util_dir.display()); + println!("cargo:RDMA_LIB_DIR={}", rdma_static_dir.display()); + println!("cargo:RDMA_UTIL_DIR={}", rdma_util_dir.display()); + println!("cargo:RDMA_CCAN_DIR={}", rdma_ccan_dir.display()); // Re-run if build scripts change println!("cargo:rerun-if-changed=build.rs"); diff --git a/monarch_rdma/build.rs b/monarch_rdma/build.rs index dfdb25e53..0c1c6b1cf 100644 --- a/monarch_rdma/build.rs +++ b/monarch_rdma/build.rs @@ -77,13 +77,8 @@ fn main() { } } } else { - match build_utils::get_cuda_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_cuda_lib_error_help(); - std::process::exit(1); - } - } + // get_cuda_lib_dir() returns String directly and panics on failure + build_utils::get_cuda_lib_dir() }; println!("cargo:rustc-link-search=native={}", compute_lib_dir); diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index 50f9e3b1e..c55c3362d 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -361,7 +361,7 @@ impl RdmaManagerActor { // Use rdmaxcel utility to get PCI address from CUDA pointer let mut pci_addr_buf: [std::os::raw::c_char; 16] = [0; 16]; // Enough space for "ffff:ff:ff.0\0" let err = rdmaxcel_sys::get_cuda_pci_address_from_ptr( - addr as u64, + ptr, pci_addr_buf.as_mut_ptr(), pci_addr_buf.len(), ); diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index e6fa85e1a..3d8dda455 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -380,12 +380,25 @@ fn main() { let cpp_static_libs_config = try_get_cpp_static_libs_config(); let rdma_include = cpp_static_libs_config.as_ref().map(|c| c.rdma_include.clone()); - // If we don't have static libs config, use dynamic linking for ibverbs/mlx5 - if cpp_static_libs_config.is_none() { + if let Some(config) = &cpp_static_libs_config { + // Explicitly emit link directives from the config if it was found. + // This ensures ccan, rdma_util, etc., are linked. + config.emit_link_directives(); + } else { + // Fallback: If metadata failed, check if we should link statically anyway. + // If monarch_cpp_static_libs ran (which it did, per logs), it emitted -L flags + // that cargo propagates automatically. We just need to ensure the libraries + // are on the link list. + + // Link main libs (could be static or dynamic, but if -L is present, static wins) println!("cargo:rustc-link-lib=ibverbs"); println!("cargo:rustc-link-lib=mlx5"); + + // FORCE link helpers: ccan and rdma_util. + // These are required by the static version of libmlx5.a/libibverbs.a + println!("cargo:rustc-link-lib=static=rdma_util"); + println!("cargo:rustc-link-lib=static=ccan"); } - // Note: If cpp_static_libs_config is Some, link directives are emitted by monarch_extension let (is_rocm, compute_home, compute_lib_names, rocm_version) = if let Ok(rocm_home) = build_utils::validate_rocm_installation() { From e3dbb415f07eaf3c9de191f9223fac508bbb2541 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Fri, 12 Dec 2025 18:14:03 +0000 Subject: [PATCH 08/12] test gpu CI unit tests pass locally now --- hyperactor/src/proc.rs | 1 + hyperactor_mesh/src/systemd.rs | 24 ++++++++++++++---------- nccl-sys/build.rs | 21 +++++++++++++++++---- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index ed62cc87e..ad64dd33b 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -2974,6 +2974,7 @@ mod tests { assert!(!root_state.load(Ordering::SeqCst)); assert!(root_1_state.load(Ordering::SeqCst)); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; assert!(!root_1_1_state.load(Ordering::SeqCst)); assert!(!root_1_1_1_state.load(Ordering::SeqCst)); assert!(!root_2_state.load(Ordering::SeqCst)); diff --git a/hyperactor_mesh/src/systemd.rs b/hyperactor_mesh/src/systemd.rs index ddcfc686f..fa882c879 100644 --- a/hyperactor_mesh/src/systemd.rs +++ b/hyperactor_mesh/src/systemd.rs @@ -227,12 +227,12 @@ mod tests { ); // Poll for unit cleanup. - for attempt in 1..=5 { + for attempt in 1..=30 { RealClock.sleep(Duration::from_secs(1)).await; if systemd.get_unit(unit_name).await.is_err() { break; } - if attempt == 5 { + if attempt == 30 { panic!("transient unit not cleaned up after {} seconds", attempt); } } @@ -355,7 +355,9 @@ mod tests { } } } - else => break, + else => { + break; + }, } } }); @@ -372,13 +374,13 @@ mod tests { ); // Poll for unit cleanup. - for attempt in 1..=5 { + for attempt in 1..=30 { RealClock.sleep(Duration::from_secs(1)).await; if systemd.get_unit(unit_name).await.is_err() { states.lock().unwrap().push(UnitState::Gone); break; } - if attempt == 10 { + if attempt == 30 { panic!("transient unit not cleaned up after {} seconds", attempt); } } @@ -403,10 +405,12 @@ mod tests { .iter() .any(|s| matches!(s, UnitState::Gone)); - assert!(has_active, "Should observe active"); + assert!(has_active, "Should observe active state"); + // Accept Gone as valid proof of shutdown - on fast systems the unit + // may be garbage collected before we observe intermediate states assert!( - has_deactivating || has_inactive, - "Should observe deactivating or inactive state during shutdown" + has_deactivating || has_inactive || has_gone, + "Should observe deactivating, inactive, or gone state during shutdown. States: {:?}", collected_states ); assert!(has_gone, "Should observe unit cleanup"); @@ -423,7 +427,7 @@ mod tests { /// NOTE: I've been unable to make this work on Meta devgpu/devvm /// infrastructure due to journal configuration/permission quirks /// (for a starting point on this goto - /// https://fb.workplace.com/groups/systemd.and.friends/permalink/3781106268771810/). + /// [https://fb.workplace.com/groups/systemd.and.friends/permalink/3781106268771810/](https://fb.workplace.com/groups/systemd.and.friends/permalink/3781106268771810/)). /// See the `test_tail_unit_logs_via_fd*` tests for a working /// alternative that uses FD-passing instead of journald. /// @@ -476,7 +480,7 @@ mod tests { .open()?; // Per - // https://www.internalfb.com/wiki/Development_Environment/Debugging_systemd_Services/#examples + // [https://www.internalfb.com/wiki/Development_Environment/Debugging_systemd_Services/#examples](https://www.internalfb.com/wiki/Development_Environment/Debugging_systemd_Services/#examples) // we are setting up for the equivalent of // `journalctl _UID=$(id -u $USER) _SYSTEMD_USER_UNIT=test-tail-logs.service -f` // but (on Meta infra) that needs to be run under `sudo` diff --git a/nccl-sys/build.rs b/nccl-sys/build.rs index f123f908c..8df14c9d0 100644 --- a/nccl-sys/build.rs +++ b/nccl-sys/build.rs @@ -21,7 +21,7 @@ fn main() { println!("cargo::rustc-check-cfg=cfg(rocm_7_plus)"); // Auto-detect ROCm vs CUDA using build_utils - let (is_rocm, compute_home, _rocm_version) = + let (is_rocm, compute_home, rocm_version) = if let Ok(rocm_home) = build_utils::validate_rocm_installation() { let version = build_utils::get_rocm_version(&rocm_home).unwrap_or((6, 0)); println!( @@ -72,7 +72,7 @@ fn main() { if bridge_cpp_hipified.exists() { let content = std::fs::read_to_string(&bridge_cpp_hipified) .expect("Failed to read hipified bridge.cpp"); - let patched = content + let mut patched = content // Fix include path to use hipified header .replace("#include \"bridge.h\"", "#include \"bridge_hip.h\"") // Replace all cudaStream_t with hipStream_t @@ -80,6 +80,14 @@ fn main() { // Patch dlopen library name for RCCL .replace("libnccl.so", "librccl.so") .replace("libnccl.so.2", "librccl.so"); + + if rocm_version.0 < 7 { + patched = patched.replace( + "LOOKUP_NCCL_ENTRY(ncclCommInitRankScalable)", + "// LOOKUP_NCCL_ENTRY(ncclCommInitRankScalable)" + ); + } + std::fs::write(&bridge_cpp_hipified, patched) .expect("Failed to write patched bridge.cpp"); } @@ -132,8 +140,13 @@ fn main() { // Communicator creation and management .allowlist_function("ncclCommInitRank") .allowlist_function("ncclCommInitAll") - .allowlist_function("ncclCommInitRankConfig") - .allowlist_function("ncclCommInitRankScalable") + .allowlist_function("ncclCommInitRankConfig"); + // It is missing in ROCm 6.x (RCCL based on NCCL < 2.20) + if !is_rocm || rocm_version.0 >= 7 { + builder = builder.allowlist_function("ncclCommInitRankScalable"); + } + + builder = builder .allowlist_function("ncclCommSplit") .allowlist_function("ncclCommFinalize") .allowlist_function("ncclCommDestroy") From 236b5cd08151674888e8d285911b06fee794d11e Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Tue, 16 Dec 2025 16:18:33 -0800 Subject: [PATCH 09/12] exclude rdmaxcel-sys from linking against libtorch Signed-off-by: Eli Uriegas --- .github/workflows/test-gpu-rust.yml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test-gpu-rust.yml b/.github/workflows/test-gpu-rust.yml index b30c59149..5cb051f7d 100644 --- a/.github/workflows/test-gpu-rust.yml +++ b/.github/workflows/test-gpu-rust.yml @@ -60,15 +60,17 @@ jobs: # internal buck test behavior. # The CI profile is configured in .config/nextest.toml # Exclude filter is for packages that don't build in Github Actions yet. - # * monarch_messages: monarch/target/debug/deps/monarch_messages-...: - # /lib64/libm.so.6: version `GLIBC_2.29' not found - # (required by /meta-pytorch/monarch/libtorch/lib/libtorch_cpu.so) + # Exclude packages that link against libtorch - the nightly libtorch requires + # GLIBC_2.29+ which is not available in the CI container. + # Error: /lib64/libm.so.6: version `GLIBC_2.29' not found + # (required by /meta-pytorch/monarch/libtorch/lib/libtorch_cpu.so) timeout 12m cargo nextest run --workspace --profile ci \ --exclude monarch_messages \ --exclude monarch_tensor_worker \ --exclude torch-sys-cuda \ --exclude monarch_rdma \ - --exclude torch-sys + --exclude torch-sys \ + --exclude rdmaxcel-sys # Copy the test results to the expected location # TODO: error in pytest-results-action, TypeError: results.testsuites.testsuite.testcase is not iterable # Don't try to parse these results for now. From ec7472d03a69967fe4f496b03a607f9b6b9ab29e Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Thu, 18 Dec 2025 21:37:47 +0000 Subject: [PATCH 10/12] more aligned with cuda static linking strategy --- build_utils/Cargo.toml | 2 +- build_utils/src/lib.rs | 2 + build_utils/src/rocm.rs | 462 +++++++++++++++++++ hyperactor/src/proc.rs | 1 - monarch_rdma/build.rs | 193 -------- rdmaxcel-sys/build.rs | 975 ++++++++++++++++------------------------ 6 files changed, 843 insertions(+), 792 deletions(-) create mode 100644 build_utils/src/rocm.rs delete mode 100644 monarch_rdma/build.rs diff --git a/build_utils/Cargo.toml b/build_utils/Cargo.toml index 084627227..ceb524a7d 100644 --- a/build_utils/Cargo.toml +++ b/build_utils/Cargo.toml @@ -14,5 +14,5 @@ doctest = false [dependencies] cc = "1.2.10" glob = "0.3.2" -pyo3-build-config = "0.24.2" +pyo3-build-config = { version = "0.24.2", features = ["resolve-config"] } which = "4.2.4" diff --git a/build_utils/src/lib.rs b/build_utils/src/lib.rs index 0fe8febf3..82c55c7cf 100644 --- a/build_utils/src/lib.rs +++ b/build_utils/src/lib.rs @@ -21,6 +21,8 @@ use std::process::Command; use glob::glob; use which::which; +pub mod rocm; + /// Python script to extract Python paths from sysconfig pub const PYTHON_PRINT_DIRS: &str = r" import sysconfig diff --git a/build_utils/src/rocm.rs b/build_utils/src/rocm.rs new file mode 100644 index 000000000..ab0b52d44 --- /dev/null +++ b/build_utils/src/rocm.rs @@ -0,0 +1,462 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! ROCm-specific build utilities for patching hipified CUDA code. +//! +//! This module provides functions to patch hipified source files for +//! compatibility with different ROCm versions: +//! - [`patch_hipified_files_rocm7`] for ROCm 7.0+ (native hipMemGetHandleForAddressRange) +//! - [`patch_hipified_files_rocm6`] for ROCm 6.x (uses HSA hsa_amd_portable_export_dmabuf) + +use std::fs; +use std::path::Path; + +// ============================================================================= +// Replacement tables - reduces duplication and makes patches easier to maintain +// ============================================================================= + +/// CUDA → HIP wrapper function name mappings +const WRAPPER_REPLACEMENTS: &[(&str, &str)] = &[ + ("rdmaxcel_cuMemGetAllocationGranularity", "rdmaxcel_hipMemGetAllocationGranularity"), + ("rdmaxcel_cuMemCreate", "rdmaxcel_hipMemCreate"), + ("rdmaxcel_cuMemAddressReserve", "rdmaxcel_hipMemAddressReserve"), + ("rdmaxcel_cuMemMap", "rdmaxcel_hipMemMap"), + ("rdmaxcel_cuMemSetAccess", "rdmaxcel_hipMemSetAccess"), + ("rdmaxcel_cuMemUnmap", "rdmaxcel_hipMemUnmap"), + ("rdmaxcel_cuMemAddressFree", "rdmaxcel_hipMemAddressFree"), + ("rdmaxcel_cuMemRelease", "rdmaxcel_hipMemRelease"), + ("rdmaxcel_cuMemcpyHtoD_v2", "rdmaxcel_hipMemcpyHtoD"), + ("rdmaxcel_cuMemcpyDtoH_v2", "rdmaxcel_hipMemcpyDtoH"), + ("rdmaxcel_cuMemsetD8_v2", "rdmaxcel_hipMemsetD8"), + ("rdmaxcel_cuPointerGetAttribute", "rdmaxcel_hipPointerGetAttribute"), + ("rdmaxcel_cuInit", "rdmaxcel_hipInit"), + ("rdmaxcel_cuDeviceGetCount", "rdmaxcel_hipDeviceGetCount"), + ("rdmaxcel_cuDeviceGetAttribute", "rdmaxcel_hipDeviceGetAttribute"), + ("rdmaxcel_cuDeviceGet", "rdmaxcel_hipDeviceGet"), + ("rdmaxcel_cuCtxCreate_v2", "rdmaxcel_hipCtxCreate"), + ("rdmaxcel_cuCtxSetCurrent", "rdmaxcel_hipCtxSetCurrent"), + ("rdmaxcel_cuGetErrorString", "rdmaxcel_hipGetErrorString"), +]; + +/// Driver API function pointer replacements (->func_() calls) +const DRIVER_API_PTR_REPLACEMENTS: &[(&str, &str)] = &[ + ("->cuMemGetAllocationGranularity_(", "->hipMemGetAllocationGranularity_("), + ("->cuMemCreate_(", "->hipMemCreate_("), + ("->cuMemAddressReserve_(", "->hipMemAddressReserve_("), + ("->cuMemMap_(", "->hipMemMap_("), + ("->cuMemSetAccess_(", "->hipMemSetAccess_("), + ("->cuMemUnmap_(", "->hipMemUnmap_("), + ("->cuMemAddressFree_(", "->hipMemAddressFree_("), + ("->cuMemRelease_(", "->hipMemRelease_("), + ("->cuMemcpyHtoD_v2_(", "->hipMemcpyHtoD_("), + ("->cuMemcpyDtoH_v2_(", "->hipMemcpyDtoH_("), + ("->cuMemsetD8_v2_(", "->hipMemsetD8_("), + ("->cuPointerGetAttribute_(", "->hipPointerGetAttribute_("), + ("->cuInit_(", "->hipInit_("), + ("->cuDeviceGet_(", "->hipDeviceGet_("), + ("->cuDeviceGetCount_(", "->hipGetDeviceCount_("), + ("->cuDeviceGetAttribute_(", "->hipDeviceGetAttribute_("), + ("->cuCtxCreate_v2_(", "->hipCtxCreate_("), + ("->cuCtxSetCurrent_(", "->hipCtxSetCurrent_("), + ("->cuCtxSynchronize_(", "->hipCtxSynchronize_("), + ("->cuGetErrorString_(", "->hipDrvGetErrorString_("), +]; + +/// Driver API macro replacements (_(func) entries) +const DRIVER_API_MACRO_REPLACEMENTS: &[(&str, &str)] = &[ + ("_(cuMemGetAllocationGranularity)", "_(hipMemGetAllocationGranularity)"), + ("_(cuMemCreate)", "_(hipMemCreate)"), + ("_(cuMemAddressReserve)", "_(hipMemAddressReserve)"), + ("_(cuMemMap)", "_(hipMemMap)"), + ("_(cuMemSetAccess)", "_(hipMemSetAccess)"), + ("_(cuMemUnmap)", "_(hipMemUnmap)"), + ("_(cuMemAddressFree)", "_(hipMemAddressFree)"), + ("_(cuMemRelease)", "_(hipMemRelease)"), + ("_(cuMemcpyHtoD_v2)", "_(hipMemcpyHtoD)"), + ("_(cuMemcpyDtoH_v2)", "_(hipMemcpyDtoH)"), + ("_(cuMemsetD8_v2)", "_(hipMemsetD8)"), + ("_(cuPointerGetAttribute)", "_(hipPointerGetAttribute)"), + ("_(cuInit)", "_(hipInit)"), + ("_(cuDeviceGet)", "_(hipDeviceGet)"), + ("_(cuDeviceGetCount)", "_(hipGetDeviceCount)"), + ("_(cuDeviceGetAttribute)", "_(hipDeviceGetAttribute)"), + ("_(cuCtxCreate_v2)", "_(hipCtxCreate)"), + ("_(cuCtxSetCurrent)", "_(hipCtxSetCurrent)"), + ("_(cuCtxSynchronize)", "_(hipCtxSynchronize)"), + ("_(cuGetErrorString)", "_(hipDrvGetErrorString)"), +]; + +// ============================================================================= +// Public API +// ============================================================================= + +/// Apply ROCm 7+ specific patches to hipified files. +/// +/// ROCm 7+ has native `hipMemGetHandleForAddressRange` support, so we can +/// use a straightforward CUDA→HIP mapping. +pub fn patch_hipified_files_rocm7(hip_src_dir: &Path) -> Result<(), Box> { + println!("cargo:warning=Patching hipified sources for ROCm 7.0+..."); + + patch_file(hip_src_dir, "rdmaxcel_hip.cpp", patch_rdmaxcel_cpp_common)?; + patch_file(hip_src_dir, "rdmaxcel_hip.h", patch_rdmaxcel_h)?; + patch_file(hip_src_dir, "driver_api_hip.h", |c| patch_driver_api_h_rocm7(&c))?; + patch_file(hip_src_dir, "driver_api_hip.cpp", |c| patch_driver_api_cpp_rocm7(&c))?; + + Ok(()) +} + +/// Apply ROCm 6.x specific patches to hipified files. +/// +/// ROCm 6.x does not have `hipMemGetHandleForAddressRange`, so we use +/// HSA's `hsa_amd_portable_export_dmabuf` instead. This requires: +/// - Adding HSA headers +/// - Replacing the handle function with HSA equivalent +/// - Using dlopen for HSA functions to avoid link-time dependency +pub fn patch_hipified_files_rocm6(hip_src_dir: &Path) -> Result<(), Box> { + println!("cargo:warning=Patching hipified sources for ROCm 6.x (HSA dmabuf)..."); + + patch_file(hip_src_dir, "rdmaxcel_hip.cpp", |c| patch_rdmaxcel_cpp_rocm6(&c))?; + patch_file(hip_src_dir, "rdmaxcel_hip.h", patch_rdmaxcel_h)?; + patch_file(hip_src_dir, "driver_api_hip.h", |c| patch_driver_api_h_rocm6(&c))?; + patch_file(hip_src_dir, "driver_api_hip.cpp", |c| { + let patched = patch_driver_api_cpp_rocm6(&c); + patch_for_dlopen(&patched) + })?; + + println!("cargo:warning=Applied dlopen patches for HIP/HSA functions"); + Ok(()) +} + +/// Validate that required hipified files exist after hipification. +pub fn validate_hipified_files(hip_src_dir: &Path) -> Result<(), Box> { + const REQUIRED: &[&str] = &["rdmaxcel_hip.h", "rdmaxcel_hip.c", "rdmaxcel_hip.cpp", "rdmaxcel.hip"]; + + for name in REQUIRED { + let path = hip_src_dir.join(name); + if !path.exists() { + return Err(format!( + "Required hipified file '{}' not found in {}", + name, + hip_src_dir.display() + ).into()); + } + } + Ok(()) +} + +// ============================================================================= +// Internal helpers +// ============================================================================= + +/// Read, transform, and write a file. No-op if file doesn't exist. +fn patch_file(dir: &Path, name: &str, f: F) -> Result<(), Box> +where + F: FnOnce(&str) -> String, +{ + let path = dir.join(name); + if path.exists() { + let content = fs::read_to_string(&path)?; + fs::write(&path, f(&content))?; + } + Ok(()) +} + +/// Apply a list of string replacements +fn apply_replacements(content: &str, replacements: &[(&str, &str)]) -> String { + let mut result = content.to_string(); + for (from, to) in replacements { + result = result.replace(from, to); + } + result +} + +// ============================================================================= +// Patch implementations - shared +// ============================================================================= + +fn patch_rdmaxcel_h(content: &str) -> String { + content + .replace("#include \"driver_api.h\"", "#include \"driver_api_hip.h\"") + .replace("CUdeviceptr", "hipDeviceptr_t") +} + +fn patch_rdmaxcel_cpp_common(content: &str) -> String { + content + .replace("#include ", + "#include \n#include ") + .replace("c10::cuda::CUDACachingAllocator", "c10::hip::HIPCachingAllocator") + .replace("c10::cuda::CUDAAllocatorConfig", "c10::hip::HIPAllocatorConfig") + .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", + "c10::hip::HIPCachingAllocator::HIPAllocatorConfig") + .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") + .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") + .replace("static_cast", "reinterpret_cast") + .replace("static_cast", "reinterpret_cast") + .replace("CUDA_SUCCESS", "hipSuccess") + .replace("CUresult", "hipError_t") +} + +// ============================================================================= +// ROCm 7+ specific patches +// ============================================================================= + +fn patch_driver_api_h_rocm7(content: &str) -> String { + let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); + result = result + .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") + .replace("CUmemRangeHandleType", "hipMemRangeHandleType"); + result +} + +fn patch_driver_api_cpp_rocm7(content: &str) -> String { + let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); + result = apply_replacements(&result, DRIVER_API_PTR_REPLACEMENTS); + result = apply_replacements(&result, DRIVER_API_MACRO_REPLACEMENTS); + result + .replace("libcuda.so.1", "libamdhip64.so") + .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") + .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)") + .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") + .replace("CUmemRangeHandleType", "hipMemRangeHandleType") +} + +// ============================================================================= +// ROCm 6.x specific patches (HSA dmabuf) +// ============================================================================= + +fn patch_rdmaxcel_cpp_rocm6(content: &str) -> String { + let mut result = content + .replace("#include ", + "#include \n#include \n#include \n#include ") + .replace("c10::cuda::CUDACachingAllocator", "c10::hip::HIPCachingAllocator") + .replace("c10::cuda::CUDAAllocatorConfig", "c10::hip::HIPAllocatorConfig") + .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", + "c10::hip::HIPCachingAllocator::HIPAllocatorConfig") + .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") + .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") + .replace("static_cast", "reinterpret_cast") + .replace("static_cast", "reinterpret_cast") + .replace("CUDA_SUCCESS", "hipSuccess") + .replace("CUdevice device", "hipDevice_t device") + .replace("cuDeviceGet(&device", "hipDeviceGet(&device") + .replace("cuDeviceGetAttribute", "hipDeviceGetAttribute") + .replace("cuPointerGetAttribute", "hipPointerGetAttribute") + .replace("CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", "hipDeviceAttributePciBusId") + .replace("CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", "hipDeviceAttributePciDeviceId") + .replace("CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", "hipDeviceAttributePciDomainID") + .replace("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL") + .replace("CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD", "0 /* HSA dmabuf */"); + + // Convert cuMemGetHandleForAddressRange to HSA with correct argument order + result = result.replace("cuMemGetHandleForAddressRange(", "hsa_amd_portable_export_dmabuf("); + + // Fix argument order for HSA call (ptr, size, fd, flags) vs CUDA (fd, ptr, size, type, flags) + result = result.replace( + "hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(start_addr),\n total_size,\n 0 /* HSA dmabuf */,\n 0);", + "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(start_addr),\n total_size,\n &fd,\n nullptr);" + ); + result = result.replace( + "hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(chunk_start),\n chunk_size,\n 0 /* HSA dmabuf */,\n 0);", + "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(chunk_start),\n chunk_size,\n &fd,\n nullptr);" + ); + + result + .replace("CUresult cu_result", "hsa_status_t hsa_result") + .replace("hipError_t cu_result", "hsa_status_t hsa_result") + .replace("cu_result != hipSuccess", "hsa_result != HSA_STATUS_SUCCESS") + .replace("if (cu_result", "if (hsa_result") + .replace("hipPointerAttribute::device", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL") +} + +fn patch_driver_api_h_rocm6(content: &str) -> String { + let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); + + // Add HSA includes if not present + if !result.contains("#include ") { + result = result.replace( + "#include ", + "#include \n#include \n#include " + ); + } + + // Replace CUDA handle function declaration with HSA version + let old_decl = "hipError_t rdmaxcel_cuMemGetHandleForAddressRange(\n int* handle,\n hipDeviceptr_t dptr,\n size_t size,\n CUmemRangeHandleType handleType,\n unsigned long long flags);"; + let new_decl = "hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf(\n void* ptr,\n size_t size,\n int* fd,\n uint64_t* flags);"; + result = result.replace(old_decl, new_decl); + result = result.replace("CUmemRangeHandleType", "int /* ROCm 6.x placeholder */"); + + // Add CUDA-compatible wrapper declaration for existing callers + result.push_str(r#" + +// CUDA-compatible wrapper for code expecting cuMemGetHandleForAddressRange +hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + int handleType, + unsigned long long flags); +"#); + result +} + +fn patch_driver_api_cpp_rocm6(content: &str) -> String { + let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); + result = apply_replacements(&result, DRIVER_API_PTR_REPLACEMENTS); + result = apply_replacements(&result, DRIVER_API_MACRO_REPLACEMENTS); + + // Add HSA includes + result = result.replace( + "#include \"driver_api_hip.h\"", + "#include \"driver_api_hip.h\"\n#include \n#include " + ); + + result = result + .replace("libcuda.so.1", "libamdhip64.so") + .replace("dstDevice, srcHost, ByteCount);", "dstDevice, const_cast(srcHost), ByteCount);") + .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") + .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)"); + + // Replace the wrapper function with HSA version + let old_wrapper = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + CUmemRangeHandleType handleType, + unsigned long long flags) { + return rdmaxcel::DriverAPI::get()->hipMemGetHandleForAddressRange_( + handle, dptr, size, handleType, flags); +}"#; + let hsa_impl = r#"hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( + void* ptr, + size_t size, + int* fd, + uint64_t* flags) { + // Direct HSA call - will be replaced with dlopen version + return hsa_amd_portable_export_dmabuf(ptr, size, fd, flags); +}"#; + result = result.replace(old_wrapper, hsa_impl); + + // Handle alternate wrapper pattern + let old_wrapper2 = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + CUmemRangeHandleType handleType, + unsigned long long flags) { + return rdmaxcel::DriverAPI::get()->cuMemGetHandleForAddressRange_( + handle, dptr, size, handleType, flags); +}"#; + result = result.replace(old_wrapper2, hsa_impl); + + result = result.replace("CUmemRangeHandleType", "int /* ROCm 6.x placeholder */"); + + // Remove hipMemGetHandleForAddressRange from dlopen macro (we use HSA instead) + result = result.replace( + "_(hipMemGetHandleForAddressRange) \\", + "/* hipMemGetHandleForAddressRange - using HSA */ \\" + ); + result = result.replace( + "_(hipMemGetHandleForAddressRange) \\", + "/* hipMemGetHandleForAddressRange - using HSA */ \\" + ); + + // Add CUDA-compatible wrapper that translates to HSA + result.push_str(r#" + +// CUDA-compatible wrapper - translates cuMemGetHandleForAddressRange to HSA +hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + int handleType, + unsigned long long flags) { + (void)handleType; // ROCm 6.x only supports dmabuf + (void)flags; + hsa_status_t status = rdmaxcel_hsa_amd_portable_export_dmabuf( + reinterpret_cast(dptr), size, handle, nullptr); + return (status == HSA_STATUS_SUCCESS) ? hipSuccess : hipErrorUnknown; +} +"#); + result +} + +/// Apply dlopen patches to avoid link-time dependencies on HIP/HSA libraries. +/// +/// This is important because: +/// 1. hipFree must be called via dlopen'd pointer, not directly +/// 2. HSA functions must be loaded lazily to avoid libhsa-runtime64.so dependency +fn patch_for_dlopen(content: &str) -> String { + let mut result = content.to_string(); + + // 1. Add hipFree to the dlopen macro list + result = result.replace( + "_(hipDrvGetErrorString)", + "_(hipDrvGetErrorString) \\\n _(hipFree)" + ); + + // 2. Reorder DriverAPI::get() to create singleton first, then call dlopen'd hipFree + result = result.replace( + r#"DriverAPI* DriverAPI::get() { + // Ensure we have a valid CUDA context for this thread + hipFree(0); + static DriverAPI singleton = create_driver_api(); + return &singleton; +}"#, + r#"DriverAPI* DriverAPI::get() { + static DriverAPI singleton = create_driver_api(); + // Ensure valid HIP context via dlopen'd hipFree (not direct call) + singleton.hipFree_(0); + return &singleton; +}"# + ); + + // 3. Replace direct HSA call with lazy dlopen version + result = result.replace( + r#"hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( + void* ptr, + size_t size, + int* fd, + uint64_t* flags) { + // Direct HSA call - will be replaced with dlopen version + return hsa_amd_portable_export_dmabuf(ptr, size, fd, flags); +}"#, + r#"// Lazy-loaded HSA function - avoids link-time dependency on libhsa-runtime64.so +static decltype(&hsa_amd_portable_export_dmabuf) g_hsa_export_dmabuf = nullptr; + +hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( + void* ptr, + size_t size, + int* fd, + uint64_t* flags) { + if (!g_hsa_export_dmabuf) { + // Try RTLD_NOLOAD first (library may already be loaded by HIP runtime) + void* handle = dlopen("libhsa-runtime64.so", RTLD_LAZY | RTLD_NOLOAD); + if (!handle) handle = dlopen("libhsa-runtime64.so", RTLD_LAZY); + if (!handle) { + throw std::runtime_error( + std::string("[RdmaXcel] Failed to load libhsa-runtime64.so: ") + dlerror()); + } + g_hsa_export_dmabuf = reinterpret_cast( + dlsym(handle, "hsa_amd_portable_export_dmabuf")); + if (!g_hsa_export_dmabuf) { + throw std::runtime_error( + std::string("[RdmaXcel] Symbol not found: hsa_amd_portable_export_dmabuf: ") + dlerror()); + } + } + return g_hsa_export_dmabuf(ptr, size, fd, flags); +}"# + ); + + // 4. Update the CUDA-compat wrapper to use our dlopen'd function + result = result.replace( + "hsa_status_t status = hsa_amd_portable_export_dmabuf(", + "hsa_status_t status = rdmaxcel_hsa_amd_portable_export_dmabuf(" + ); + + result +} diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index f1cb13609..30301a837 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -2995,7 +2995,6 @@ mod tests { assert!(!root_state.load(Ordering::SeqCst)); assert!(root_1_state.load(Ordering::SeqCst)); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; assert!(!root_1_1_state.load(Ordering::SeqCst)); assert!(!root_1_1_1_state.load(Ordering::SeqCst)); assert!(!root_2_state.load(Ordering::SeqCst)); diff --git a/monarch_rdma/build.rs b/monarch_rdma/build.rs deleted file mode 100644 index 0c1c6b1cf..000000000 --- a/monarch_rdma/build.rs +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#[cfg(target_os = "macos")] -fn main() {} - -#[cfg(not(target_os = "macos"))] -fn main() { - // Check if we are building for ROCm (HIP) - check ROCm first - let is_rocm = build_utils::find_rocm_home().is_some(); - - // Validate compute installation and set cfg flags - if is_rocm { - match build_utils::validate_rocm_installation() { - Ok(_) => println!("cargo:warning=Using ROCm/HIP for monarch_rdma"), - Err(_) => { - build_utils::print_rocm_error_help(); - std::process::exit(1); - } - } - - // Set ROCm version cfg flags - let rocm_version = build_utils::find_rocm_home() - .and_then(|home| build_utils::get_rocm_version(&home)) - .unwrap_or((6, 0)); - - if rocm_version.0 >= 7 { - println!("cargo:rustc-cfg=rocm_7_plus"); - } else { - println!("cargo:rustc-cfg=rocm_6_x"); - } - } else { - match build_utils::validate_cuda_installation() { - Ok(_) => println!("cargo:warning=Using CUDA for monarch_rdma"), - Err(_) => { - build_utils::print_cuda_error_help(); - std::process::exit(1); - } - } - } - - // Emit cfg check declarations - println!("cargo:rustc-check-cfg=cfg(rocm_6_x)"); - println!("cargo:rustc-check-cfg=cfg(rocm_7_plus)"); - - // Include headers and libs from the active environment. - let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { - Ok(config) => config, - Err(_) => { - eprintln!("Warning: Failed to get Python environment directories"); - build_utils::PythonConfig { - include_dir: None, - lib_dir: None, - } - } - }; - - if let Some(lib_dir) = &python_config.lib_dir { - println!("cargo:rustc-link-search=native={}", lib_dir); - // Set cargo metadata to inform dependent binaries about how to set their - // RPATH (see controller/build.rs for an example). - println!("cargo:metadata=LIB_PATH={}", lib_dir); - } - - // Get compute library directory and emit link directives - let compute_lib_dir = if is_rocm { - match build_utils::get_rocm_lib_dir() { - Ok(dir) => dir, - Err(_) => { - build_utils::print_rocm_lib_error_help(); - std::process::exit(1); - } - } - } else { - // get_cuda_lib_dir() returns String directly and panics on failure - build_utils::get_cuda_lib_dir() - }; - println!("cargo:rustc-link-search=native={}", compute_lib_dir); - - // Link compute libraries - if is_rocm { - println!("cargo:rustc-link-lib=amdhip64"); - println!("cargo:rustc-link-lib=hsa-runtime64"); - } else { - println!("cargo:rustc-link-lib=cuda"); - println!("cargo:rustc-link-lib=cudart"); - } - - // Link against the ibverbs and mlx5 libraries (used by rdmaxcel-sys) - println!("cargo:rustc-link-lib=ibverbs"); - println!("cargo:rustc-link-lib=mlx5"); - - // Link PyTorch libraries needed for C10 symbols used by rdmaxcel-sys - let use_pytorch_apis = build_utils::get_env_var_with_rerun("TORCH_SYS_USE_PYTORCH_APIS") - .unwrap_or_else(|_| "1".to_owned()); - if use_pytorch_apis == "1" { - // Get PyTorch library directory using build_utils - let python_interpreter = std::path::PathBuf::from("python"); - if let Ok(output) = std::process::Command::new(&python_interpreter) - .arg("-c") - .arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS) - .output() - { - if output.status.success() { - for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { - // Add library search path - println!("cargo:rustc-link-search=native={}", path); - // Set rpath so runtime linker can find the libraries - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path); - } - } - } - } - - // Link core PyTorch libraries needed for C10 symbols - println!("cargo:rustc-link-lib=torch_cpu"); - println!("cargo:rustc-link-lib=torch"); - println!("cargo:rustc-link-lib=c10"); - if is_rocm { - println!("cargo:rustc-link-lib=c10_hip"); - } else { - println!("cargo:rustc-link-lib=c10_cuda"); - } - } else { - // Fallback to torch-sys links metadata if available - if let Ok(torch_lib_path) = std::env::var("DEP_TORCH_LIB_PATH") { - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", torch_lib_path); - } - } - - // Set rpath for NCCL libraries if available - if let Ok(nccl_lib_path) = std::env::var("DEP_NCCL_LIB_PATH") { - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", nccl_lib_path); - } - - // Disable new dtags, as conda envs generally use `RPATH` over `RUNPATH` - println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags"); - - // Link the static libraries from rdmaxcel-sys - // Try the Cargo dependency mechanism first, then fall back to fixed paths - if let Ok(rdmaxcel_out_dir) = std::env::var("DEP_RDMAXCEL_SYS_OUT_DIR") { - println!("cargo:rustc-link-search=native={}", rdmaxcel_out_dir); - println!("cargo:rustc-link-lib=static=rdmaxcel"); - println!("cargo:rustc-link-lib=static=rdmaxcel_cpp"); - println!("cargo:rustc-link-lib=static=rdmaxcel_cuda"); - } else { - eprintln!("Warning: DEP_RDMAXCEL_SYS_OUT_DIR not found. Using fallback paths."); - - // Use relative paths to the known locations - let cuda_build_dir = "../rdmaxcel-sys/target/cuda_build"; - println!("cargo:rustc-link-search=native={}", cuda_build_dir); - println!("cargo:rustc-link-lib=static=rdmaxcel_cuda"); - - // Find the most recent rdmaxcel-sys build directory for C/C++ libraries - let monarch_target_dir = "../target/debug/build"; - if let Ok(entries) = std::fs::read_dir(monarch_target_dir) { - let mut rdmaxcel_dirs: Vec<_> = entries - .filter_map(|entry| entry.ok()) - .filter(|entry| { - entry - .file_name() - .to_string_lossy() - .starts_with("rdmaxcel-sys-") - }) - .collect(); - - // Sort by modification time and use the most recent - rdmaxcel_dirs - .sort_by_key(|entry| entry.metadata().ok().and_then(|m| m.modified().ok())); - - if let Some(most_recent) = rdmaxcel_dirs.last() { - let out_dir = most_recent.path().join("out"); - if out_dir.exists() { - println!("cargo:rustc-link-search=native={}", out_dir.display()); - println!("cargo:rustc-link-lib=static=rdmaxcel"); - println!("cargo:rustc-link-lib=static=rdmaxcel_cpp"); - } - } else { - eprintln!("Warning: No rdmaxcel-sys build directories found"); - } - } - } - - // Set build configuration flags - println!("cargo:rustc-cfg=cargo"); - println!("cargo:rustc-check-cfg=cfg(cargo)"); -} diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs index 3d8dda455..02adea70b 100644 --- a/rdmaxcel-sys/build.rs +++ b/rdmaxcel-sys/build.rs @@ -6,659 +6,440 @@ * LICENSE file in the root directory of this source tree. */ +//! Build script for rdmaxcel-sys +//! +//! Supports both CUDA and ROCm backends. ROCm support requires hipification +//! of CUDA sources and version-specific patches. + use std::env; -use std::fs; -use std::path::Path; use std::path::PathBuf; use std::process::Command; -// ============================================================================= -// Hipify Patching Functions (specific to rdmaxcel-sys) -// ============================================================================= +#[cfg(target_os = "macos")] +fn main() {} -fn rename_rdmaxcel_wrappers(content: &str) -> String { - content - .replace("rdmaxcel_cuMemGetAllocationGranularity", "rdmaxcel_hipMemGetAllocationGranularity") - .replace("rdmaxcel_cuMemCreate", "rdmaxcel_hipMemCreate") - .replace("rdmaxcel_cuMemAddressReserve", "rdmaxcel_hipMemAddressReserve") - .replace("rdmaxcel_cuMemMap", "rdmaxcel_hipMemMap") - .replace("rdmaxcel_cuMemSetAccess", "rdmaxcel_hipMemSetAccess") - .replace("rdmaxcel_cuMemUnmap", "rdmaxcel_hipMemUnmap") - .replace("rdmaxcel_cuMemAddressFree", "rdmaxcel_hipMemAddressFree") - .replace("rdmaxcel_cuMemRelease", "rdmaxcel_hipMemRelease") - .replace("rdmaxcel_cuMemcpyHtoD_v2", "rdmaxcel_hipMemcpyHtoD") - .replace("rdmaxcel_cuMemcpyDtoH_v2", "rdmaxcel_hipMemcpyDtoH") - .replace("rdmaxcel_cuMemsetD8_v2", "rdmaxcel_hipMemsetD8") - .replace("rdmaxcel_cuPointerGetAttribute", "rdmaxcel_hipPointerGetAttribute") - .replace("rdmaxcel_cuInit", "rdmaxcel_hipInit") - .replace("rdmaxcel_cuDeviceGetCount", "rdmaxcel_hipDeviceGetCount") - .replace("rdmaxcel_cuDeviceGetAttribute", "rdmaxcel_hipDeviceGetAttribute") - .replace("rdmaxcel_cuDeviceGet", "rdmaxcel_hipDeviceGet") - .replace("rdmaxcel_cuCtxCreate_v2", "rdmaxcel_hipCtxCreate") - .replace("rdmaxcel_cuCtxSetCurrent", "rdmaxcel_hipCtxSetCurrent") - .replace("rdmaxcel_cuGetErrorString", "rdmaxcel_hipGetErrorString") -} +#[cfg(not(target_os = "macos"))] +fn main() { + // Declare cfg flags + println!("cargo::rustc-check-cfg=cfg(cargo)"); + println!("cargo::rustc-check-cfg=cfg(rocm_6_x)"); + println!("cargo::rustc-check-cfg=cfg(rocm_7_plus)"); -fn patch_hipified_files_rocm7(hip_src_dir: &Path) -> Result<(), Box> { - println!("cargo:warning=Patching hipify_torch output for ROCm 7.0+..."); - - let cpp_file = hip_src_dir.join("rdmaxcel_hip.cpp"); - if cpp_file.exists() { - let content = fs::read_to_string(&cpp_file)?; - let patched_content = content - .replace("#include ", "#include \n#include ") - .replace("c10::cuda::CUDACachingAllocator", "c10::hip::HIPCachingAllocator") - .replace("c10::cuda::CUDAAllocatorConfig", "c10::hip::HIPAllocatorConfig") - .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", "c10::hip::HIPCachingAllocator::HIPAllocatorConfig") - .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") - .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") - .replace("static_cast", "reinterpret_cast") - .replace("static_cast", "reinterpret_cast") - .replace("CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD", "hipMemRangeHandleTypeDmaBufFd") - .replace("cuMemGetHandleForAddressRange", "hipMemGetHandleForAddressRange") - .replace("CUDA_SUCCESS", "hipSuccess") - .replace("CUresult", "hipError_t"); - fs::write(&cpp_file, patched_content)?; - } + let platform = detect_platform(); + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let src_dir = manifest_dir.join("src"); + + // Get RDMA includes from monarch_cpp_static_libs + let cpp_libs = build_utils::CppStaticLibsConfig::from_env(); - let header_file = hip_src_dir.join("rdmaxcel_hip.h"); - if header_file.exists() { - let content = fs::read_to_string(&header_file)?; - let patched_content = content - .replace("#include \"driver_api.h\"", "#include \"driver_api_hip.h\"") - .replace("CUdeviceptr", "hipDeviceptr_t"); - fs::write(&header_file, patched_content)?; - } + // Setup linking + println!("cargo:rustc-link-lib=dl"); + println!("cargo:rustc-link-search=native={}", platform.lib_dir()); + platform.emit_link_libs(); - let driver_api_h = hip_src_dir.join("driver_api_hip.h"); - if driver_api_h.exists() { - let content = fs::read_to_string(&driver_api_h)?; - let mut patched_content = rename_rdmaxcel_wrappers(&content); - patched_content = patched_content - .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") - .replace("CUmemRangeHandleType", "hipMemRangeHandleType"); - fs::write(&driver_api_h, patched_content)?; + // Setup rerun triggers + for f in &["rdmaxcel.h", "rdmaxcel.c", "rdmaxcel.cpp", "rdmaxcel.cu", "driver_api.h", "driver_api.cpp"] { + println!("cargo:rerun-if-changed=src/{}", f); } - let driver_api_cpp = hip_src_dir.join("driver_api_hip.cpp"); - if driver_api_cpp.exists() { - let content = fs::read_to_string(&driver_api_cpp)?; - let mut patched_content = rename_rdmaxcel_wrappers(&content); - patched_content = patched_content - .replace("libcuda.so.1", "libamdhip64.so") - .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") - .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)") - .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") - .replace("->cuMemGetAllocationGranularity_(", "->hipMemGetAllocationGranularity_(") - .replace("->cuMemCreate_(", "->hipMemCreate_(") - .replace("->cuMemAddressReserve_(", "->hipMemAddressReserve_(") - .replace("->cuMemMap_(", "->hipMemMap_(") - .replace("->cuMemSetAccess_(", "->hipMemSetAccess_(") - .replace("->cuMemUnmap_(", "->hipMemUnmap_(") - .replace("->cuMemAddressFree_(", "->hipMemAddressFree_(") - .replace("->cuMemRelease_(", "->hipMemRelease_(") - .replace("->cuMemcpyHtoD_v2_(", "->hipMemcpyHtoD_(") - .replace("->cuMemcpyDtoH_v2_(", "->hipMemcpyDtoH_(") - .replace("->cuMemsetD8_v2_(", "->hipMemsetD8_(") - .replace("->cuPointerGetAttribute_(", "->hipPointerGetAttribute_(") - .replace("->cuInit_(", "->hipInit_(") - .replace("->cuDeviceGet_(", "->hipDeviceGet_(") - .replace("->cuDeviceGetCount_(", "->hipGetDeviceCount_(") - .replace("->cuDeviceGetAttribute_(", "->hipDeviceGetAttribute_(") - .replace("->cuCtxCreate_v2_(", "->hipCtxCreate_(") - .replace("->cuCtxSetCurrent_(", "->hipCtxSetCurrent_(") - .replace("->cuCtxSynchronize_(", "->hipCtxSynchronize_(") - .replace("_(cuCtxSynchronize)", "_(hipCtxSynchronize)") - .replace("->cuGetErrorString_(", "->hipDrvGetErrorString_(") - .replace("CUmemRangeHandleType", "hipMemRangeHandleType"); - fs::write(&driver_api_cpp, patched_content)?; + // Build + if let Ok(out_dir) = env::var("OUT_DIR") { + let out_path = PathBuf::from(&out_dir); + let sources = platform.prepare_sources(&src_dir, &out_path); + + let python_config = build_utils::python_env_dirs_with_interpreter("python3") + .unwrap_or(build_utils::PythonConfig { include_dir: None, lib_dir: None }); + + generate_bindings(&sources, &platform, &cpp_libs.rdma_include, &python_config, &out_path); + compile_c(&sources, &platform, &cpp_libs.rdma_include); + compile_cpp(&sources, &platform, &cpp_libs.rdma_include, &python_config); + compile_gpu(&sources, &platform, &cpp_libs.rdma_include, &manifest_dir, &out_path); } - Ok(()) + + println!("cargo:rustc-env=CUDA_INCLUDE_PATH={}", platform.include_dir()); + println!("cargo:rustc-cfg=cargo"); } -fn patch_hipified_files_rocm6(hip_src_dir: &Path) -> Result<(), Box> { - println!("cargo:warning=Patching hipify_torch output for ROCm 6.x (HSA dmabuf)..."); - - let cpp_file = hip_src_dir.join("rdmaxcel_hip.cpp"); - if cpp_file.exists() { - let content = fs::read_to_string(&cpp_file)?; - let mut patched_content = content - .replace("#include ", "#include \n#include \n#include \n#include ") - .replace("c10::cuda::CUDACachingAllocator", "c10::hip::HIPCachingAllocator") - .replace("c10::cuda::CUDAAllocatorConfig", "c10::hip::HIPAllocatorConfig") - .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", "c10::hip::HIPCachingAllocator::HIPAllocatorConfig") - .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") - .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") - .replace("static_cast", "reinterpret_cast") - .replace("static_cast", "reinterpret_cast") - .replace("CUDA_SUCCESS", "hipSuccess") - .replace("CUdevice device", "hipDevice_t device") - .replace("cuDeviceGet(&device", "hipDeviceGet(&device") - .replace("cuDeviceGetAttribute", "hipDeviceGetAttribute") - .replace("cuPointerGetAttribute", "hipPointerGetAttribute") - .replace("CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", "hipDeviceAttributePciBusId") - .replace("CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", "hipDeviceAttributePciDeviceId") - .replace("CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", "hipDeviceAttributePciDomainID") - .replace("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL") - .replace("CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD", "0 /* HSA dmabuf */"); - - patched_content = patched_content.replace("cuMemGetHandleForAddressRange(", "hsa_amd_portable_export_dmabuf("); - patched_content = patched_content.replace("hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(start_addr),\n total_size,\n 0 /* HSA dmabuf */,\n 0);", "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(start_addr),\n total_size,\n &fd,\n nullptr);"); - patched_content = patched_content.replace("hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(chunk_start),\n chunk_size,\n 0 /* HSA dmabuf */,\n 0);", "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(chunk_start),\n chunk_size,\n &fd,\n nullptr);"); - patched_content = patched_content - .replace("CUresult cu_result", "hsa_status_t hsa_result") - .replace("hipError_t cu_result", "hsa_status_t hsa_result") - .replace("cu_result != hipSuccess", "hsa_result != HSA_STATUS_SUCCESS") - .replace("if (cu_result", "if (hsa_result"); - patched_content = patched_content.replace("hipPointerAttribute::device", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL"); - fs::write(&cpp_file, patched_content)?; - } +// ============================================================================= +// Platform abstraction +// ============================================================================= - let header_file = hip_src_dir.join("rdmaxcel_hip.h"); - if header_file.exists() { - let content = fs::read_to_string(&header_file)?; - let patched_content = content - .replace("#include \"driver_api.h\"", "#include \"driver_api_hip.h\"") - .replace("CUdeviceptr", "hipDeviceptr_t"); - fs::write(&header_file, patched_content)?; - } +enum Platform { + Cuda { home: String }, + Rocm { home: String, version: (u32, u32) }, +} - let driver_api_h = hip_src_dir.join("driver_api_hip.h"); - if driver_api_h.exists() { - let content = fs::read_to_string(&driver_api_h)?; - let mut patched_content = rename_rdmaxcel_wrappers(&content); - if !patched_content.contains("#include ") { - patched_content = patched_content.replace("#include ", "#include \n#include \n#include "); +impl Platform { + fn include_dir(&self) -> String { + match self { + Platform::Cuda { home } | Platform::Rocm { home, .. } => format!("{}/include", home), } - let old_decl = "hipError_t rdmaxcel_cuMemGetHandleForAddressRange(\n int* handle,\n hipDeviceptr_t dptr,\n size_t size,\n CUmemRangeHandleType handleType,\n unsigned long long flags);"; - let new_decl = "hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf(\n void* ptr,\n size_t size,\n int* fd,\n uint64_t* flags);"; - patched_content = patched_content.replace(old_decl, new_decl); - patched_content = patched_content.replace("CUmemRangeHandleType", "int /* placeholder - ROCm 6.x */"); - patched_content.push_str("\n\n// CUDA-compatible wrapper for monarch_rdma\nhipError_t rdmaxcel_cuMemGetHandleForAddressRange(\n int* handle,\n hipDeviceptr_t dptr,\n size_t size,\n int handleType,\n unsigned long long flags);\n"); - fs::write(&driver_api_h, patched_content)?; } - let driver_api_cpp = hip_src_dir.join("driver_api_hip.cpp"); - if driver_api_cpp.exists() { - let content = fs::read_to_string(&driver_api_cpp)?; - let mut patched_content = rename_rdmaxcel_wrappers(&content); - patched_content = patched_content.replace("#include \"driver_api_hip.h\"", "#include \"driver_api_hip.h\"\n#include \n#include "); - patched_content = patched_content.replace("libcuda.so.1", "libamdhip64.so"); - patched_content = patched_content.replace("dstDevice, srcHost, ByteCount);", "dstDevice, const_cast(srcHost), ByteCount);"); - patched_content = patched_content - .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") - .replace("->cuMemGetAllocationGranularity_(", "->hipMemGetAllocationGranularity_(") - .replace("->cuMemCreate_(", "->hipMemCreate_(") - .replace("->cuMemAddressReserve_(", "->hipMemAddressReserve_(") - .replace("->cuMemMap_(", "->hipMemMap_(") - .replace("->cuMemSetAccess_(", "->hipMemSetAccess_(") - .replace("->cuMemUnmap_(", "->hipMemUnmap_(") - .replace("->cuMemAddressFree_(", "->hipMemAddressFree_(") - .replace("->cuMemRelease_(", "->hipMemRelease_(") - .replace("->cuMemcpyHtoD_v2_(", "->hipMemcpyHtoD_(") - .replace("->cuMemcpyDtoH_v2_(", "->hipMemcpyDtoH_(") - .replace("->cuMemsetD8_v2_(", "->hipMemsetD8_(") - .replace("->cuPointerGetAttribute_(", "->hipPointerGetAttribute_(") - .replace("->cuInit_(", "->hipInit_(") - .replace("->cuDeviceGet_(", "->hipDeviceGet_(") - .replace("->cuDeviceGetCount_(", "->hipGetDeviceCount_(") - .replace("->cuDeviceGetAttribute_(", "->hipDeviceGetAttribute_(") - .replace("->cuCtxCreate_v2_(", "->hipCtxCreate_(") - .replace("->cuCtxSetCurrent_(", "->hipCtxSetCurrent_(") - .replace("->cuCtxSynchronize_(", "->hipCtxSynchronize_(") - .replace("->cuGetErrorString_(", "->hipDrvGetErrorString_("); - - patched_content = patched_content - .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)") - .replace("_(cuMemGetAllocationGranularity)", "_(hipMemGetAllocationGranularity)") - .replace("_(cuMemCreate)", "_(hipMemCreate)") - .replace("_(cuMemAddressReserve)", "_(hipMemAddressReserve)") - .replace("_(cuMemMap)", "_(hipMemMap)") - .replace("_(cuMemSetAccess)", "_(hipMemSetAccess)") - .replace("_(cuMemUnmap)", "_(hipMemUnmap)") - .replace("_(cuMemAddressFree)", "_(hipMemAddressFree)") - .replace("_(cuMemRelease)", "_(hipMemRelease)") - .replace("_(cuMemcpyHtoD_v2)", "_(hipMemcpyHtoD)") - .replace("_(cuMemcpyDtoH_v2)", "_(hipMemcpyDtoH)") - .replace("_(cuMemsetD8_v2)", "_(hipMemsetD8)") - .replace("_(cuPointerGetAttribute)", "_(hipPointerGetAttribute)") - .replace("_(cuInit)", "_(hipInit)") - .replace("_(cuDeviceGet)", "_(hipDeviceGet)") - .replace("_(cuDeviceGetCount)", "_(hipGetDeviceCount)") - .replace("_(cuDeviceGetAttribute)", "_(hipDeviceGetAttribute)") - .replace("_(cuCtxCreate_v2)", "_(hipCtxCreate)") - .replace("_(cuCtxSetCurrent)", "_(hipCtxSetCurrent)") - .replace("_(cuCtxSynchronize)", "_(hipCtxSynchronize)") - .replace("_(cuGetErrorString)", "_(hipDrvGetErrorString)"); - - let old_wrapper = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( - int* handle, - hipDeviceptr_t dptr, - size_t size, - CUmemRangeHandleType handleType, - unsigned long long flags) { - return rdmaxcel::DriverAPI::get()->hipMemGetHandleForAddressRange_( - handle, dptr, size, handleType, flags); -}"#; - let new_wrapper = r#"hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( - void* ptr, - size_t size, - int* fd, - uint64_t* flags) { - // Direct HSA call for ROCm 6.x - bypasses DriverAPI dynamic loading - return hsa_amd_portable_export_dmabuf(ptr, size, fd, flags); -}"#; - patched_content = patched_content.replace(old_wrapper, new_wrapper); - let old_wrapper2 = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( - int* handle, - hipDeviceptr_t dptr, - size_t size, - CUmemRangeHandleType handleType, - unsigned long long flags) { - return rdmaxcel::DriverAPI::get()->cuMemGetHandleForAddressRange_( - handle, dptr, size, handleType, flags); -}"#; - patched_content = patched_content.replace(old_wrapper2, new_wrapper); - patched_content = patched_content.replace("CUmemRangeHandleType", "int /* placeholder - ROCm 6.x */"); - patched_content = patched_content.replace("_(hipMemGetHandleForAddressRange) \\", "/* hipMemGetHandleForAddressRange removed for ROCm 6.x - using HSA */ \\"); - patched_content = patched_content.replace("_(hipMemGetHandleForAddressRange) \\", "/* hipMemGetHandleForAddressRange removed for ROCm 6.x - using HSA */ \\"); - - let cuda_compat_wrapper = r#" -// CUDA-compatible wrapper for monarch_rdma - translates to HSA call -hipError_t rdmaxcel_cuMemGetHandleForAddressRange( - int* handle, - hipDeviceptr_t dptr, - size_t size, - int handleType, - unsigned long long flags) { - (void)handleType; // unused - ROCm 6.x only supports dmabuf - (void)flags; // unused - hsa_status_t status = hsa_amd_portable_export_dmabuf( - reinterpret_cast(dptr), - size, - handle, - nullptr); - return (status == HSA_STATUS_SUCCESS) ? hipSuccess : hipErrorUnknown; -} -"#; - patched_content.push_str(cuda_compat_wrapper); - fs::write(&driver_api_cpp, patched_content)?; + fn lib_dir(&self) -> String { + match self { + Platform::Cuda { home } => build_utils::get_cuda_lib_dir(), + Platform::Rocm { home, .. } => { + build_utils::get_rocm_lib_dir().expect("Failed to get ROCm lib dir") + } + } } - Ok(()) -} -fn validate_hipified_files(hip_src_dir: &Path) -> Result<(), Box> { - let required_files = [ - "rdmaxcel_hip.h", - "rdmaxcel_hip.c", - "rdmaxcel_hip.cpp", - "rdmaxcel.hip", - ]; - for file_name in &required_files { - let file_path = hip_src_dir.join(file_name); - if !file_path.exists() { - return Err(format!("Required hipified file {} was not found", file_name).into()); + fn compiler(&self) -> String { + match self { + Platform::Cuda { home } => format!("{}/bin/nvcc", home), + Platform::Rocm { home, .. } => format!("{}/bin/hipcc", home), } } - Ok(()) -} - -/// Hipify sources for rdmaxcel-sys using build_utils::run_hipify_torch -/// and apply rdmaxcel-specific patches based on ROCm version -fn hipify_sources( - src_dir: &Path, - hip_src_dir: &Path, - rocm_version: (u32, u32), -) -> Result<(), Box> { - println!("cargo:warning=Hipifying sources from {} to {}...", src_dir.display(), hip_src_dir.display()); - - // Collect source files to hipify - let files_to_copy = [ - "lib.rs", "rdmaxcel.h", "rdmaxcel.c", "rdmaxcel.cpp", "rdmaxcel.cu", - "test_rdmaxcel.c", "driver_api.h", "driver_api.cpp", - ]; - let source_files: Vec = files_to_copy - .iter() - .map(|f| src_dir.join(f)) - .filter(|p| p.exists()) - .collect(); + fn is_rocm(&self) -> bool { + matches!(self, Platform::Rocm { .. }) + } - // Find project root - let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); - let project_root = manifest_dir.parent().ok_or("Failed to find project root")?; + fn rocm_version(&self) -> (u32, u32) { + match self { + Platform::Rocm { version, .. } => *version, + Platform::Cuda { .. } => (0, 0), + } + } - // Use centralized hipify function from build_utils - build_utils::run_hipify_torch(project_root, &source_files, hip_src_dir) - .map_err(|e| format!("hipify_torch failed: {}", e))?; + fn emit_link_libs(&self) { + match self { + Platform::Cuda { .. } => { + // CUDA: static runtime, dlopen driver API + println!("cargo:rustc-link-lib=static=cudart_static"); + println!("cargo:rustc-link-lib=rt"); + println!("cargo:rustc-link-lib=pthread"); + } + Platform::Rocm { .. } => { + // ROCm: all driver API via dlopen + // Note: hipcc-compiled code still requires libamdhip64.so at runtime + } + } + } - // Apply rdmaxcel-specific patches based on ROCm version - let (major, _minor) = rocm_version; - if major >= 7 { - patch_hipified_files_rocm7(hip_src_dir)?; - } else { - patch_hipified_files_rocm6(hip_src_dir)?; + fn prepare_sources(&self, src_dir: &PathBuf, out_path: &PathBuf) -> Sources { + match self { + Platform::Cuda { .. } => Sources { + dir: src_dir.clone(), + header: src_dir.join("rdmaxcel.h"), + c_source: src_dir.join("rdmaxcel.c"), + cpp_source: src_dir.join("rdmaxcel.cpp"), + gpu_source: src_dir.join("rdmaxcel.cu"), + driver_api: src_dir.join("driver_api.cpp"), + }, + Platform::Rocm { version, .. } => { + let hip_dir = out_path.join("hipified_src"); + hipify_sources(src_dir, &hip_dir, *version); + Sources { + dir: hip_dir.clone(), + header: hip_dir.join("rdmaxcel_hip.h"), + c_source: hip_dir.join("rdmaxcel_hip.c"), + cpp_source: hip_dir.join("rdmaxcel_hip.cpp"), + gpu_source: hip_dir.join("rdmaxcel.hip"), + driver_api: hip_dir.join("driver_api_hip.cpp"), + } + } + } } - Ok(()) -} + fn add_defines(&self, build: &mut cc::Build) { + if let Platform::Rocm { version, .. } = self { + build.define("__HIP_PLATFORM_AMD__", "1"); + build.define("USE_ROCM", "1"); + if version.0 >= 7 { + build.define("ROCM_7_PLUS", "1"); + } else { + build.define("ROCM_6_X", "1"); + } + } + } -fn get_libtorch_include_dirs(python_interpreter: &Path) -> Vec { - let mut include_dirs = Vec::new(); - if let Ok(output) = Command::new(python_interpreter).arg("-c").arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS).output() { - for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { - include_dirs.push(PathBuf::from(path)); + fn clang_defines(&self) -> Vec { + match self { + Platform::Cuda { .. } => vec![], + Platform::Rocm { version, .. } => { + let mut defs = vec![ + "-D__HIP_PLATFORM_AMD__=1".into(), + "-DUSE_ROCM=1".into(), + ]; + if version.0 >= 7 { + defs.push("-DROCM_7_PLUS=1".into()); + } else { + defs.push("-DROCM_6_X=1".into()); + } + defs } } } - include_dirs -} -/// Try to get rdma-core config from cpp_static_libs, returns None if not available -fn try_get_cpp_static_libs_config() -> Option { - // Check if the required environment variables are set - if std::env::var("DEP_MONARCH_CPP_STATIC_LIBS_RDMA_INCLUDE").is_ok() { - Some(build_utils::CppStaticLibsConfig::from_env()) - } else { - None + fn compiler_args(&self) -> Vec { + match self { + Platform::Cuda { .. } => vec![ + "--compiler-options".into(), + "-fPIC".into(), + "-std=c++14".into(), + "--expt-extended-lambda".into(), + "-Xcompiler".into(), + "-fPIC".into(), + ], + Platform::Rocm { version, .. } => { + let mut args = vec![ + "-std=c++14".into(), + "-D__HIP_PLATFORM_AMD__=1".into(), + "-DUSE_ROCM=1".into(), + ]; + if version.0 >= 7 { + args.push("-DROCM_7_PLUS=1".into()); + } else { + args.push("-DROCM_6_X=1".into()); + } + args + } + } } } -#[cfg(target_os = "macos")] -fn main() {} - -#[cfg(not(target_os = "macos"))] -fn main() { - // Declare custom cfg options to avoid warnings - println!("cargo::rustc-check-cfg=cfg(cargo)"); - println!("cargo::rustc-check-cfg=cfg(rocm_6_x)"); - println!("cargo::rustc-check-cfg=cfg(rocm_7_plus)"); +struct Sources { + dir: PathBuf, + header: PathBuf, + c_source: PathBuf, + cpp_source: PathBuf, + gpu_source: PathBuf, + driver_api: PathBuf, +} - // Try to get rdma-core config from cpp_static_libs (upstream approach) - // If not available, fall back to dynamic linking - let cpp_static_libs_config = try_get_cpp_static_libs_config(); - let rdma_include = cpp_static_libs_config.as_ref().map(|c| c.rdma_include.clone()); +// ============================================================================= +// Platform detection +// ============================================================================= - if let Some(config) = &cpp_static_libs_config { - // Explicitly emit link directives from the config if it was found. - // This ensures ccan, rdma_util, etc., are linked. - config.emit_link_directives(); - } else { - // Fallback: If metadata failed, check if we should link statically anyway. - // If monarch_cpp_static_libs ran (which it did, per logs), it emitted -L flags - // that cargo propagates automatically. We just need to ensure the libraries - // are on the link list. +fn detect_platform() -> Platform { + // Try ROCm first (ROCm systems may also have CUDA installed) + if let Ok(home) = build_utils::validate_rocm_installation() { + let version = build_utils::get_rocm_version(&home).unwrap_or((6, 0)); + println!("cargo:warning=Using HIP/ROCm {}.{} from {}", version.0, version.1, home); - // Link main libs (could be static or dynamic, but if -L is present, static wins) - println!("cargo:rustc-link-lib=ibverbs"); - println!("cargo:rustc-link-lib=mlx5"); + if version.0 >= 7 { + println!("cargo:rustc-cfg=rocm_7_plus"); + } else { + println!("cargo:rustc-cfg=rocm_6_x"); + } - // FORCE link helpers: ccan and rdma_util. - // These are required by the static version of libmlx5.a/libibverbs.a - println!("cargo:rustc-link-lib=static=rdma_util"); - println!("cargo:rustc-link-lib=static=ccan"); + return Platform::Rocm { home, version }; } - let (is_rocm, compute_home, compute_lib_names, rocm_version) = - if let Ok(rocm_home) = build_utils::validate_rocm_installation() { - let version = build_utils::get_rocm_version(&rocm_home).unwrap_or((6, 0)); - println!("cargo:warning=Using HIP/ROCm {} from {}", format!("{}.{}", version.0, version.1), rocm_home); - if version.0 >= 7 { println!("cargo:rustc-cfg=rocm_7_plus"); } else { println!("cargo:rustc-cfg=rocm_6_x"); } - (true, rocm_home, vec!["amdhip64", "hsa-runtime64"], version) - } else if let Ok(cuda_home) = build_utils::validate_cuda_installation() { - println!("cargo:warning=Using CUDA from {}", cuda_home); - (false, cuda_home, vec![], (0, 0)) // CUDA libs handled below - } else { - eprintln!("Error: Neither CUDA nor ROCm installation found!"); - std::process::exit(1); - }; + // Fall back to CUDA + if let Ok(home) = build_utils::validate_cuda_installation() { + println!("cargo:warning=Using CUDA from {}", home); + return Platform::Cuda { home }; + } + + eprintln!("Error: Neither CUDA nor ROCm installation found!"); + build_utils::print_cuda_error_help(); + std::process::exit(1); +} + +// ============================================================================= +// Hipification (ROCm only) +// ============================================================================= + +fn hipify_sources(src_dir: &PathBuf, hip_dir: &PathBuf, version: (u32, u32)) { + println!("cargo:warning=Hipifying sources to {}...", hip_dir.display()); + + let files: Vec = [ + "lib.rs", "rdmaxcel.h", "rdmaxcel.c", "rdmaxcel.cpp", + "rdmaxcel.cu", "test_rdmaxcel.c", "driver_api.h", "driver_api.cpp" + ].iter() + .map(|f| src_dir.join(f)) + .filter(|p| p.exists()) + .collect(); let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); - let src_dir = manifest_dir.join("src"); - let python_interpreter = build_utils::find_python_interpreter(); - let compute_include_path = format!("{}/include", compute_home); - println!("cargo:rustc-env=CUDA_INCLUDE_PATH={}", compute_include_path); - - let python_config = match build_utils::python_env_dirs_with_interpreter("python3") { - Ok(config) => config, - Err(_) => build_utils::PythonConfig { include_dir: None, lib_dir: None }, - }; - - // Platform-specific library linking - if is_rocm { - let compute_lib_dir = build_utils::get_rocm_lib_dir().unwrap(); - println!("cargo:rustc-link-search=native={}", compute_lib_dir); - for lib_name in &compute_lib_names { - println!("cargo:rustc-link-lib={}", lib_name); - } - } else { - // CUDA: Link cudart statically (upstream approach) - let cuda_lib_dir = build_utils::get_cuda_lib_dir(); - println!("cargo:rustc-link-search=native={}", cuda_lib_dir); - // Note: libcuda is now loaded dynamically via dlopen in driver_api.cpp - // Link cudart statically (CUDA Runtime API) - println!("cargo:rustc-link-lib=static=cudart_static"); - // cudart_static requires linking against librt and libpthread - println!("cargo:rustc-link-lib=rt"); - println!("cargo:rustc-link-lib=pthread"); - } - println!("cargo:rustc-link-lib=dl"); + let project_root = manifest_dir.parent().expect("Failed to find project root"); - let use_pytorch_apis = build_utils::get_env_var_with_rerun("TORCH_SYS_USE_PYTORCH_APIS").unwrap_or_else(|_| "1".to_owned()); - let libtorch_include_dirs: Vec = if use_pytorch_apis == "1" { - get_libtorch_include_dirs(&python_interpreter) + build_utils::run_hipify_torch(project_root, &files, hip_dir) + .expect("hipify_torch failed"); + + // Apply version-specific patches + if version.0 >= 7 { + build_utils::rocm::patch_hipified_files_rocm7(hip_dir) + .expect("ROCm 7+ patching failed"); } else { - Vec::new() - }; - - if use_pytorch_apis == "1" { - if let Ok(output) = Command::new(&python_interpreter).arg("-c").arg(build_utils::PYTHON_PRINT_PYTORCH_DETAILS).output() { - for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { - println!("cargo:rustc-link-search=native={}", path); - break; - } - } - } - println!("cargo:rustc-link-lib=torch_cpu"); - println!("cargo:rustc-link-lib=torch"); - println!("cargo:rustc-link-lib=c10"); - if is_rocm { println!("cargo:rustc-link-lib=c10_hip"); } else { println!("cargo:rustc-link-lib=c10_cuda"); } + build_utils::rocm::patch_hipified_files_rocm6(hip_dir) + .expect("ROCm 6.x patching failed"); } - match env::var("OUT_DIR") { - Ok(out_dir) => { - let out_path = PathBuf::from(&out_dir); - let (code_dir, header_path, c_source_path, cpp_source_path, cuda_source_path, driver_api_cpp_path); - - if is_rocm { - let hip_src_dir = out_path.join("hipified_src"); - hipify_sources(&src_dir, &hip_src_dir, rocm_version).expect("Failed to hipify sources"); - validate_hipified_files(&hip_src_dir).expect("Hipified files validation failed"); - code_dir = hip_src_dir.clone(); - header_path = hip_src_dir.join("rdmaxcel_hip.h"); - c_source_path = hip_src_dir.join("rdmaxcel_hip.c"); - cpp_source_path = hip_src_dir.join("rdmaxcel_hip.cpp"); - cuda_source_path = hip_src_dir.join("rdmaxcel.hip"); - driver_api_cpp_path = hip_src_dir.join("driver_api_hip.cpp"); - } else { - code_dir = src_dir.clone(); - header_path = src_dir.join("rdmaxcel.h"); - c_source_path = src_dir.join("rdmaxcel.c"); - cpp_source_path = src_dir.join("rdmaxcel.cpp"); - cuda_source_path = src_dir.join("rdmaxcel.cu"); - driver_api_cpp_path = src_dir.join("driver_api.cpp"); - } + build_utils::rocm::validate_hipified_files(hip_dir) + .expect("Hipified file validation failed"); +} - // Bindgen setup - let mut builder = bindgen::Builder::default() - .header(header_path.to_string_lossy()) - .clang_arg("-x").clang_arg("c++").clang_arg("-std=c++14") - .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - .allowlist_function("ibv_.*").allowlist_function("mlx5dv_.*").allowlist_function("mlx5_wqe_.*") - .allowlist_function("create_qp").allowlist_function("create_mlx5dv_.*") - .allowlist_function("register_cuda_memory").allowlist_function("register_hip_memory") - .allowlist_function("db_ring").allowlist_function("cqe_poll") - .allowlist_function("send_wqe").allowlist_function("recv_wqe") - .allowlist_function("launch_db_ring").allowlist_function("launch_cqe_poll") - .allowlist_function("launch_send_wqe").allowlist_function("launch_recv_wqe") - .allowlist_function("rdma_get_active_segment_count").allowlist_function("rdma_get_all_segment_info") - .allowlist_function("pt_cuda_allocator_compatibility").allowlist_function("pt_hip_allocator_compatibility") - .allowlist_function("register_segments").allowlist_function("deregister_segments") - .allowlist_function("register_dmabuf_buffer").allowlist_function("get_hip_pci_address_from_ptr") - .allowlist_function("get_cuda_pci_address_from_ptr") - .allowlist_function("rdmaxcel_cu.*").allowlist_function("rdmaxcel_hip.*").allowlist_function("rdmaxcel_hsa.*") - .allowlist_function("rdmaxcel_qp_.*").allowlist_function("rdmaxcel_print_device_info") - .allowlist_function("rdmaxcel_error_string").allowlist_function("completion_cache_.*") - .allowlist_function("poll_cq_with_cache").allowlist_function("rdmaxcel_register_segment_scanner") - .allowlist_type("rdmaxcel_qp_t").allowlist_type("rdmaxcel_qp") - .allowlist_type("rdmaxcel_error_code_t").allowlist_type("completion_cache_t") - .allowlist_type("completion_cache").allowlist_type("completion_node_t") - .allowlist_type("completion_node").allowlist_type("poll_context_t") - .allowlist_type("poll_context").allowlist_type("rdma_qp_type_t") - .allowlist_type("rdmaxcel_segment_scanner_fn") - .allowlist_type("rdmaxcel_scanned_segment_t") - .allowlist_type("CUdeviceptr").allowlist_type("CUdevice").allowlist_type("CUresult") - .allowlist_type("CUcontext").allowlist_type("CUmemRangeHandleType") - .allowlist_var("CUDA_SUCCESS").allowlist_var("CU_.*") - .allowlist_type("hipDeviceptr_t").allowlist_type("hipDevice_t").allowlist_type("hipError_t") - .allowlist_type("hipCtx_t").allowlist_type("hipPointer_attribute") - .allowlist_var("hipSuccess").allowlist_var("HIP_.*") - .allowlist_type("hsa_status_t").allowlist_var("HSA_STATUS_SUCCESS") - .allowlist_type("ibv_.*").allowlist_type("mlx5dv_.*").allowlist_type("mlx5_wqe_.*") - .allowlist_type("cqe_poll_result_t").allowlist_type("wqe_params_t") - .allowlist_type("cqe_poll_params_t").allowlist_type("rdma_segment_info_t") - .allowlist_var("MLX5_.*").allowlist_var("IBV_.*").allowlist_var("RDMA_QP_TYPE_.*") - .blocklist_type("ibv_wc").blocklist_type("mlx5_wqe_ctrl_seg") - .bitfield_enum("ibv_access_flags").bitfield_enum("ibv_qp_attr_mask") - .bitfield_enum("ibv_wc_flags").bitfield_enum("ibv_send_flags") - .bitfield_enum("ibv_port_cap_flags").constified_enum_module("ibv_qp_type") - .constified_enum_module("ibv_qp_state").constified_enum_module("ibv_port_state") - .constified_enum_module("ibv_wc_opcode").constified_enum_module("ibv_wr_opcode") - .constified_enum_module("ibv_wc_status") - .derive_default(true).prepend_enum_name(false); - - builder = builder.clang_arg(format!("-I{}", compute_include_path)); - - // Add rdma-core include path - if let Some(ref rdma_inc) = rdma_include { - builder = builder.clang_arg(format!("-I{}", rdma_inc)); - } +// ============================================================================= +// Compilation +// ============================================================================= - if is_rocm { - builder = builder.clang_arg("-D__HIP_PLATFORM_AMD__=1").clang_arg("-DUSE_ROCM=1"); - if rocm_version.0 >= 7 { builder = builder.clang_arg("-DROCM_7_PLUS=1"); } else { builder = builder.clang_arg("-DROCM_6_X=1"); } - } - if let Some(include_dir) = &python_config.include_dir { - builder = builder.clang_arg(format!("-I{}", include_dir)); - } - let bindings = builder.generate().expect("Unable to generate bindings"); - bindings.write_to_file(out_path.join("bindings.rs")).expect("Couldn't write bindings"); - - println!("cargo:rustc-cfg=cargo"); - - // Compile C files (rdmaxcel.c) - if c_source_path.exists() { - let mut build = cc::Build::new(); - build.file(&c_source_path).include(&code_dir).flag("-fPIC"); - build.include(&compute_include_path); - if let Some(ref rdma_inc) = rdma_include { - build.include(rdma_inc); - } - if is_rocm { - build.define("__HIP_PLATFORM_AMD__", "1").define("USE_ROCM", "1"); - if rocm_version.0 >= 7 { build.define("ROCM_7_PLUS", "1"); } else { build.define("ROCM_6_X", "1"); } - } - build.compile("rdmaxcel"); - } +fn generate_bindings( + sources: &Sources, + platform: &Platform, + rdma_include: &str, + python_config: &build_utils::PythonConfig, + out_path: &PathBuf, +) { + let mut builder = bindgen::Builder::default() + .header(sources.header.to_string_lossy()) + .clang_arg("-x").clang_arg("c++").clang_arg("-std=c++14") + .clang_arg(format!("-I{}", platform.include_dir())) + .clang_arg(format!("-I{}", rdma_include)) + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + // Functions + .allowlist_function("ibv_.*").allowlist_function("mlx5dv_.*") + .allowlist_function("create_qp").allowlist_function("create_mlx5dv_.*") + .allowlist_function("register_cuda_memory").allowlist_function("register_hip_memory") + .allowlist_function("db_ring").allowlist_function("cqe_poll") + .allowlist_function("send_wqe").allowlist_function("recv_wqe") + .allowlist_function("launch_.*").allowlist_function("rdma_get_.*") + .allowlist_function("pt_.*_allocator_compatibility") + .allowlist_function("register_segments").allowlist_function("deregister_segments") + .allowlist_function("register_dmabuf_buffer") + .allowlist_function("get_.*_pci_address_from_ptr") + .allowlist_function("rdmaxcel_.*") + .allowlist_function("completion_cache_.*").allowlist_function("poll_cq_with_cache") + // Types + .allowlist_type("rdmaxcel_.*").allowlist_type("completion_.*") + .allowlist_type("poll_context.*").allowlist_type("rdma_qp_type_t") + .allowlist_type("CU.*").allowlist_type("hip.*").allowlist_type("hsa_status_t") + .allowlist_type("ibv_.*").allowlist_type("mlx5.*") + .allowlist_type("cqe_poll_.*").allowlist_type("wqe_params_t") + .allowlist_type("rdma_segment_info_t") + // Vars + .allowlist_var("CUDA_SUCCESS").allowlist_var("CU_.*") + .allowlist_var("hipSuccess").allowlist_var("HIP_.*") + .allowlist_var("HSA_STATUS_SUCCESS") + .allowlist_var("MLX5_.*").allowlist_var("IBV_.*").allowlist_var("RDMA_QP_TYPE_.*") + // Config + .blocklist_type("ibv_wc").blocklist_type("mlx5_wqe_ctrl_seg") + .bitfield_enum("ibv_access_flags").bitfield_enum("ibv_qp_attr_mask") + .bitfield_enum("ibv_wc_flags").bitfield_enum("ibv_send_flags") + .bitfield_enum("ibv_port_cap_flags") + .constified_enum_module("ibv_qp_type").constified_enum_module("ibv_qp_state") + .constified_enum_module("ibv_port_state").constified_enum_module("ibv_wc_opcode") + .constified_enum_module("ibv_wr_opcode").constified_enum_module("ibv_wc_status") + .derive_default(true).prepend_enum_name(false); + + for def in platform.clang_defines() { + builder = builder.clang_arg(def); + } + + if let Some(ref dir) = python_config.include_dir { + builder = builder.clang_arg(format!("-I{}", dir)); + } - // Compile C++ files (rdmaxcel.cpp) - if cpp_source_path.exists() { - let mut cpp_build = cc::Build::new(); - cpp_build.file(&cpp_source_path).include(&code_dir).flag("-fPIC").cpp(true).flag("-std=c++14"); - if let Some(ref rdma_inc) = rdma_include { - cpp_build.include(rdma_inc); - } - // Suppress deprecated API warnings for HIP context management APIs (deprecated in ROCm 6.x) - if is_rocm { - cpp_build.flag("-Wno-deprecated-declarations"); - } - if driver_api_cpp_path.exists() { cpp_build.file(&driver_api_cpp_path); } - cpp_build.include(&compute_include_path); - if is_rocm { - cpp_build.define("__HIP_PLATFORM_AMD__", "1").define("USE_ROCM", "1"); - if rocm_version.0 >= 7 { cpp_build.define("ROCM_7_PLUS", "1"); } else { cpp_build.define("ROCM_6_X", "1"); } - } - for include_dir in &libtorch_include_dirs { cpp_build.include(include_dir); } - if let Some(include_dir) = &python_config.include_dir { cpp_build.include(include_dir); } - cpp_build.compile("rdmaxcel_cpp"); + builder.generate() + .expect("Unable to generate bindings") + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings"); +} - // Statically link libstdc++ to avoid runtime dependency on system libstdc++ (upstream) - build_utils::link_libstdcpp_static(); - } +fn compile_c(sources: &Sources, platform: &Platform, rdma_include: &str) { + if !sources.c_source.exists() { return; } + + let mut build = cc::Build::new(); + build + .file(&sources.c_source) + .include(&sources.dir) + .include(platform.include_dir()) + .include(rdma_include) + .flag("-fPIC"); + + platform.add_defines(&mut build); + build.compile("rdmaxcel"); +} - // Compile CUDA/HIP files - if cuda_source_path.exists() { - let compiler_path = if is_rocm { format!("{}/bin/hipcc", compute_home) } else { format!("{}/bin/nvcc", compute_home) }; - let cuda_build_dir = format!("{}/target/cuda_build", manifest_dir.display()); - std::fs::create_dir_all(&cuda_build_dir).expect("Failed to create CUDA build directory"); - let cuda_obj_path = format!("{}/rdmaxcel_cuda.o", cuda_build_dir); - let cuda_lib_path = format!("{}/librdmaxcel_cuda.a", cuda_build_dir); - - let mut compiler_args: Vec = vec![ - "-c".to_string(), - cuda_source_path.to_str().unwrap().to_string(), - "-o".to_string(), - cuda_obj_path.clone(), - "-fPIC".to_string(), - format!("-I{}", compute_include_path), - format!("-I{}", code_dir.display()), - "-I/usr/include".to_string(), - "-I/usr/include/infiniband".to_string(), - ]; +fn compile_cpp( + sources: &Sources, + platform: &Platform, + rdma_include: &str, + python_config: &build_utils::PythonConfig, +) { + if !sources.cpp_source.exists() { return; } + + let mut build = cc::Build::new(); + build + .file(&sources.cpp_source) + .include(&sources.dir) + .include(platform.include_dir()) + .include(rdma_include) + .flag("-fPIC") + .cpp(true) + .flag("-std=c++14"); + + if sources.driver_api.exists() { + build.file(&sources.driver_api); + } + + platform.add_defines(&mut build); + + if platform.is_rocm() { + build.flag("-Wno-deprecated-declarations"); + } + + if let Some(ref dir) = python_config.include_dir { + build.include(dir); + } + + build.compile("rdmaxcel_cpp"); + build_utils::link_libstdcpp_static(); +} - if let Some(ref rdma_inc) = rdma_include { - compiler_args.push(format!("-I{}", rdma_inc)); - } +fn compile_gpu( + sources: &Sources, + platform: &Platform, + rdma_include: &str, + manifest_dir: &PathBuf, + out_path: &PathBuf, +) { + if !sources.gpu_source.exists() { return; } + + let build_dir = format!("{}/target/cuda_build", manifest_dir.display()); + std::fs::create_dir_all(&build_dir).expect("Failed to create build directory"); + + let obj_path = format!("{}/rdmaxcel_cuda.o", build_dir); + let lib_path = format!("{}/librdmaxcel_cuda.a", build_dir); + + let mut args = vec![ + "-c".to_string(), + sources.gpu_source.to_string_lossy().to_string(), + "-o".to_string(), + obj_path.clone(), + "-fPIC".to_string(), + format!("-I{}", platform.include_dir()), + format!("-I{}", sources.dir.display()), + format!("-I{}", rdma_include), + "-I/usr/include".to_string(), + "-I/usr/include/infiniband".to_string(), + ]; + args.extend(platform.compiler_args()); + + let output = Command::new(platform.compiler()) + .args(&args) + .output() + .expect("Failed to run GPU compiler"); + + if !output.status.success() { + panic!( + "GPU compilation failed:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + } - let compiler_output = if is_rocm { - compiler_args.push("-std=c++14".to_string()); - compiler_args.push("-D__HIP_PLATFORM_AMD__=1".to_string()); - compiler_args.push("-DUSE_ROCM=1".to_string()); - if rocm_version.0 >= 7 { - compiler_args.push("-DROCM_7_PLUS=1".to_string()); - } else { - compiler_args.push("-DROCM_6_X=1".to_string()); - } - Command::new(&compiler_path).args(&compiler_args).output() - } else { - compiler_args.insert(4, "--compiler-options".to_string()); - compiler_args.insert(6, "-std=c++14".to_string()); - compiler_args.insert(7, "--expt-extended-lambda".to_string()); - compiler_args.insert(8, "-Xcompiler".to_string()); - compiler_args.insert(9, "-fPIC".to_string()); - Command::new(&compiler_path).args(&compiler_args).output() - }; - - match compiler_output { - Ok(output) => { - if !output.status.success() { panic!("Failed to compile CUDA/HIP source: {}", String::from_utf8_lossy(&output.stderr)); } - } - Err(e) => panic!("Failed to run compiler: {}", e), - } + let ar_output = Command::new("ar") + .args(["rcs", &lib_path, &obj_path]) + .output(); - let ar_output = Command::new("ar").args(["rcs", &cuda_lib_path, &cuda_obj_path]).output(); - if let Ok(output) = ar_output { - if !output.status.success() { panic!("Failed to create static library"); } - println!("cargo:rustc-link-lib=static=rdmaxcel_cuda"); - println!("cargo:rustc-link-search=native={}", cuda_build_dir); - if let Err(e) = std::fs::copy(&cuda_lib_path, out_path.join("librdmaxcel_cuda.a")) { - eprintln!("Warning: Failed to copy CUDA library: {}", e); - } - } - } + if let Ok(out) = ar_output { + if out.status.success() { + println!("cargo:rustc-link-lib=static=rdmaxcel_cuda"); + println!("cargo:rustc-link-search=native={}", build_dir); + let _ = std::fs::copy(&lib_path, out_path.join("librdmaxcel_cuda.a")); } - Err(_) => println!("Note: OUT_DIR not set"), } } From 49d36c28bbb74952293dbb37506e32d52d80fc80 Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Fri, 19 Dec 2025 20:20:07 +0000 Subject: [PATCH 11/12] build successfully on rocm7.0 --- build_utils/src/rocm.rs | 464 ++++++++++++++++++++-------------------- rdmaxcel-sys/src/lib.rs | 55 ++--- 2 files changed, 258 insertions(+), 261 deletions(-) diff --git a/build_utils/src/rocm.rs b/build_utils/src/rocm.rs index ab0b52d44..70d1ccd15 100644 --- a/build_utils/src/rocm.rs +++ b/build_utils/src/rocm.rs @@ -12,63 +12,39 @@ //! compatibility with different ROCm versions: //! - [`patch_hipified_files_rocm7`] for ROCm 7.0+ (native hipMemGetHandleForAddressRange) //! - [`patch_hipified_files_rocm6`] for ROCm 6.x (uses HSA hsa_amd_portable_export_dmabuf) +//! +//! IMPORTANT: Both paths keep wrapper function names as `rdmaxcel_cu*` for API stability. +//! Only the internal implementations differ (HIP vs HSA). use std::fs; use std::path::Path; // ============================================================================= -// Replacement tables - reduces duplication and makes patches easier to maintain +// Replacement tables // ============================================================================= -/// CUDA → HIP wrapper function name mappings -const WRAPPER_REPLACEMENTS: &[(&str, &str)] = &[ - ("rdmaxcel_cuMemGetAllocationGranularity", "rdmaxcel_hipMemGetAllocationGranularity"), - ("rdmaxcel_cuMemCreate", "rdmaxcel_hipMemCreate"), - ("rdmaxcel_cuMemAddressReserve", "rdmaxcel_hipMemAddressReserve"), - ("rdmaxcel_cuMemMap", "rdmaxcel_hipMemMap"), - ("rdmaxcel_cuMemSetAccess", "rdmaxcel_hipMemSetAccess"), - ("rdmaxcel_cuMemUnmap", "rdmaxcel_hipMemUnmap"), - ("rdmaxcel_cuMemAddressFree", "rdmaxcel_hipMemAddressFree"), - ("rdmaxcel_cuMemRelease", "rdmaxcel_hipMemRelease"), - ("rdmaxcel_cuMemcpyHtoD_v2", "rdmaxcel_hipMemcpyHtoD"), - ("rdmaxcel_cuMemcpyDtoH_v2", "rdmaxcel_hipMemcpyDtoH"), - ("rdmaxcel_cuMemsetD8_v2", "rdmaxcel_hipMemsetD8"), - ("rdmaxcel_cuPointerGetAttribute", "rdmaxcel_hipPointerGetAttribute"), - ("rdmaxcel_cuInit", "rdmaxcel_hipInit"), - ("rdmaxcel_cuDeviceGetCount", "rdmaxcel_hipDeviceGetCount"), - ("rdmaxcel_cuDeviceGetAttribute", "rdmaxcel_hipDeviceGetAttribute"), - ("rdmaxcel_cuDeviceGet", "rdmaxcel_hipDeviceGet"), - ("rdmaxcel_cuCtxCreate_v2", "rdmaxcel_hipCtxCreate"), - ("rdmaxcel_cuCtxSetCurrent", "rdmaxcel_hipCtxSetCurrent"), - ("rdmaxcel_cuGetErrorString", "rdmaxcel_hipGetErrorString"), +/// CUDA CU_* constants → HIP equivalents +/// hipify_torch does not convert these constants in rdmaxcel_hip.cpp +const CUDA_CONSTANT_REPLACEMENTS: &[(&str, &str)] = &[ + ("CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD", "hipMemRangeHandleTypeDmaBufFd"), + ("CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", "hipDeviceAttributePciBusId"), + ("CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", "hipDeviceAttributePciDeviceId"), + ("CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", "hipDeviceAttributePciDomainID"), + ("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL"), + ("CUDA_SUCCESS", "hipSuccess"), ]; -/// Driver API function pointer replacements (->func_() calls) -const DRIVER_API_PTR_REPLACEMENTS: &[(&str, &str)] = &[ - ("->cuMemGetAllocationGranularity_(", "->hipMemGetAllocationGranularity_("), - ("->cuMemCreate_(", "->hipMemCreate_("), - ("->cuMemAddressReserve_(", "->hipMemAddressReserve_("), - ("->cuMemMap_(", "->hipMemMap_("), - ("->cuMemSetAccess_(", "->hipMemSetAccess_("), - ("->cuMemUnmap_(", "->hipMemUnmap_("), - ("->cuMemAddressFree_(", "->hipMemAddressFree_("), - ("->cuMemRelease_(", "->hipMemRelease_("), - ("->cuMemcpyHtoD_v2_(", "->hipMemcpyHtoD_("), - ("->cuMemcpyDtoH_v2_(", "->hipMemcpyDtoH_("), - ("->cuMemsetD8_v2_(", "->hipMemsetD8_("), - ("->cuPointerGetAttribute_(", "->hipPointerGetAttribute_("), - ("->cuInit_(", "->hipInit_("), - ("->cuDeviceGet_(", "->hipDeviceGet_("), - ("->cuDeviceGetCount_(", "->hipGetDeviceCount_("), - ("->cuDeviceGetAttribute_(", "->hipDeviceGetAttribute_("), - ("->cuCtxCreate_v2_(", "->hipCtxCreate_("), - ("->cuCtxSetCurrent_(", "->hipCtxSetCurrent_("), - ("->cuCtxSynchronize_(", "->hipCtxSynchronize_("), - ("->cuGetErrorString_(", "->hipDrvGetErrorString_("), +/// CUDA type replacements that hipify_torch may miss +const CUDA_TYPE_REPLACEMENTS: &[(&str, &str)] = &[ + ("CUresult", "hipError_t"), + ("CUdevice device", "hipDevice_t device"), + ("CUmemRangeHandleType", "hipMemRangeHandleType"), ]; -/// Driver API macro replacements (_(func) entries) -const DRIVER_API_MACRO_REPLACEMENTS: &[(&str, &str)] = &[ +/// Macro entry replacements for driver_api_hip.cpp: _(cuXxx) → _(hipXxx) +/// These are entries in the RDMAXCEL_CUDA_DRIVER_API macro for dlsym lookups +const MACRO_ENTRY_REPLACEMENTS: &[(&str, &str)] = &[ + ("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)"), ("_(cuMemGetAllocationGranularity)", "_(hipMemGetAllocationGranularity)"), ("_(cuMemCreate)", "_(hipMemCreate)"), ("_(cuMemAddressReserve)", "_(hipMemAddressReserve)"), @@ -91,20 +67,55 @@ const DRIVER_API_MACRO_REPLACEMENTS: &[(&str, &str)] = &[ ("_(cuGetErrorString)", "_(hipDrvGetErrorString)"), ]; +/// Struct member access replacements for driver_api_hip.cpp wrapper implementations +/// These fix the ->cuXxx_( calls inside the wrapper functions +const MEMBER_ACCESS_REPLACEMENTS: &[(&str, &str)] = &[ + ("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_("), + ("->cuMemGetAllocationGranularity_(", "->hipMemGetAllocationGranularity_("), + ("->cuMemCreate_(", "->hipMemCreate_("), + ("->cuMemAddressReserve_(", "->hipMemAddressReserve_("), + ("->cuMemMap_(", "->hipMemMap_("), + ("->cuMemSetAccess_(", "->hipMemSetAccess_("), + ("->cuMemUnmap_(", "->hipMemUnmap_("), + ("->cuMemAddressFree_(", "->hipMemAddressFree_("), + ("->cuMemRelease_(", "->hipMemRelease_("), + ("->cuMemcpyHtoD_v2_(", "->hipMemcpyHtoD_("), + ("->cuMemcpyDtoH_v2_(", "->hipMemcpyDtoH_("), + ("->cuMemsetD8_v2_(", "->hipMemsetD8_("), + ("->cuPointerGetAttribute_(", "->hipPointerGetAttribute_("), + ("->cuInit_(", "->hipInit_("), + ("->cuDeviceGet_(", "->hipDeviceGet_("), + ("->cuDeviceGetCount_(", "->hipGetDeviceCount_("), + ("->cuDeviceGetAttribute_(", "->hipDeviceGetAttribute_("), + ("->cuCtxCreate_v2_(", "->hipCtxCreate_("), + ("->cuCtxSetCurrent_(", "->hipCtxSetCurrent_("), + ("->cuCtxSynchronize_(", "->hipCtxSynchronize_("), + ("->cuGetErrorString_(", "->hipDrvGetErrorString_("), +]; + // ============================================================================= // Public API // ============================================================================= /// Apply ROCm 7+ specific patches to hipified files. /// -/// ROCm 7+ has native `hipMemGetHandleForAddressRange` support, so we can -/// use a straightforward CUDA→HIP mapping. +/// ROCm 7+ has native `hipMemGetHandleForAddressRange` support. +/// +/// Key design: We keep wrapper function names as `rdmaxcel_cu*` for API stability. +/// Only the internal HIP API calls are converted. pub fn patch_hipified_files_rocm7(hip_src_dir: &Path) -> Result<(), Box> { println!("cargo:warning=Patching hipified sources for ROCm 7.0+..."); - patch_file(hip_src_dir, "rdmaxcel_hip.cpp", patch_rdmaxcel_cpp_common)?; + // rdmaxcel_hip.cpp - fix constants and types, keep wrapper function names as rdmaxcel_cu* + patch_file(hip_src_dir, "rdmaxcel_hip.cpp", |c| patch_rdmaxcel_cpp_rocm7(&c))?; + + // Header needs driver_api include path fix patch_file(hip_src_dir, "rdmaxcel_hip.h", patch_rdmaxcel_h)?; + + // driver_api_hip.h - fix any remaining CUDA types patch_file(hip_src_dir, "driver_api_hip.h", |c| patch_driver_api_h_rocm7(&c))?; + + // driver_api_hip.cpp - comprehensive patching: macro entries, member access, types patch_file(hip_src_dir, "driver_api_hip.cpp", |c| patch_driver_api_cpp_rocm7(&c))?; Ok(()) @@ -113,10 +124,9 @@ pub fn patch_hipified_files_rocm7(hip_src_dir: &Path) -> Result<(), Box Result<(), Box> { println!("cargo:warning=Patching hipified sources for ROCm 6.x (HSA dmabuf)..."); @@ -135,7 +145,7 @@ pub fn patch_hipified_files_rocm6(hip_src_dir: &Path) -> Result<(), Box Result<(), Box> { const REQUIRED: &[&str] = &["rdmaxcel_hip.h", "rdmaxcel_hip.c", "rdmaxcel_hip.cpp", "rdmaxcel.hip"]; - + for name in REQUIRED { let path = hip_src_dir.join(name); if !path.exists() { @@ -176,7 +186,7 @@ fn apply_replacements(content: &str, replacements: &[(&str, &str)]) -> String { } // ============================================================================= -// Patch implementations - shared +// ROCm 7+ patches // ============================================================================= fn patch_rdmaxcel_h(content: &str) -> String { @@ -185,50 +195,103 @@ fn patch_rdmaxcel_h(content: &str) -> String { .replace("CUdeviceptr", "hipDeviceptr_t") } -fn patch_rdmaxcel_cpp_common(content: &str) -> String { - content - .replace("#include ", - "#include \n#include ") +/// Patch rdmaxcel_hip.cpp for ROCm 7+ +/// +/// IMPORTANT: We do NOT rename wrapper function calls (rdmaxcel_cu* stays rdmaxcel_cu*) +/// We only fix: +/// - CU_* constants → HIP equivalents +/// - Type conversions +/// - c10 namespace references +fn patch_rdmaxcel_cpp_rocm7(content: &str) -> String { + let mut result = content.to_string(); + + // Add hip_version.h include + result = result.replace( + "#include ", + "#include \n#include " + ); + + // Fix c10 namespace (hipify sometimes misses nested references) + result = result .replace("c10::cuda::CUDACachingAllocator", "c10::hip::HIPCachingAllocator") .replace("c10::cuda::CUDAAllocatorConfig", "c10::hip::HIPAllocatorConfig") - .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", + .replace("c10::hip::HIPCachingAllocator::CUDAAllocatorConfig", "c10::hip::HIPCachingAllocator::HIPAllocatorConfig") - .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") - .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") + .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::"); + + // Fix static_cast to reinterpret_cast for device pointers + result = result .replace("static_cast", "reinterpret_cast") - .replace("static_cast", "reinterpret_cast") - .replace("CUDA_SUCCESS", "hipSuccess") - .replace("CUresult", "hipError_t") + .replace("static_cast", "reinterpret_cast"); + + // Fix PCI domain attribute case (hipify produces wrong case) + result = result.replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID"); + + // Apply CU_* constant replacements + result = apply_replacements(&result, CUDA_CONSTANT_REPLACEMENTS); + + // Apply type replacements + result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + + // NOTE: We intentionally do NOT rename wrapper function calls here. + // The wrapper functions keep their rdmaxcel_cu* names for API stability. + + result } -// ============================================================================= -// ROCm 7+ specific patches -// ============================================================================= - +/// Patch driver_api_hip.h for ROCm 7+ +/// hipify_torch handles most conversions, but may miss some types fn patch_driver_api_h_rocm7(content: &str) -> String { - let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); - result = result - .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") - .replace("CUmemRangeHandleType", "hipMemRangeHandleType"); - result + // Apply type replacements (hipify may miss CUmemRangeHandleType) + // Do NOT rename wrapper function declarations - they stay as rdmaxcel_cu* + apply_replacements(content, CUDA_TYPE_REPLACEMENTS) } +/// Patch driver_api_hip.cpp for ROCm 7+ +/// This needs comprehensive patching because hipify_torch doesn't convert: +/// 1. Macro entries like _(cuMemCreate) in RDMAXCEL_CUDA_DRIVER_API +/// 2. Struct member access like ->cuMemCreate_( in wrapper implementations +/// 3. Some types +/// +/// NOTE: Wrapper function names (rdmaxcel_cu*) are NOT changed. fn patch_driver_api_cpp_rocm7(content: &str) -> String { - let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); - result = apply_replacements(&result, DRIVER_API_PTR_REPLACEMENTS); - result = apply_replacements(&result, DRIVER_API_MACRO_REPLACEMENTS); + let mut result = content.to_string(); + + // Fix library name + result = result.replace("libcuda.so.1", "libamdhip64.so"); + + // Fix runtime header + result = result.replace("#include ", "#include "); + result = result.replace("cudaFree", "hipFree"); + + // Apply macro entry replacements: _(cuXxx) → _(hipXxx) + // These are the actual HIP function names for dlsym lookup + result = apply_replacements(&result, MACRO_ENTRY_REPLACEMENTS); + + // Apply struct member access replacements: ->cuXxx_( → ->hipXxx_( + // These are the function pointers stored in the DriverAPI struct + result = apply_replacements(&result, MEMBER_ACCESS_REPLACEMENTS); + + // Apply type replacements + result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + + // Fix const_cast for HtoD (srcHost needs to be non-const for HIP) + result = result.replace( + "dstDevice, srcHost, ByteCount);", + "dstDevice, const_cast(srcHost), ByteCount);" + ); + + // NOTE: Wrapper function names (rdmaxcel_cu*) are intentionally NOT changed. + result - .replace("libcuda.so.1", "libamdhip64.so") - .replace("rdmaxcel_cuMemGetHandleForAddressRange", "rdmaxcel_hipMemGetHandleForAddressRange") - .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)") - .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") - .replace("CUmemRangeHandleType", "hipMemRangeHandleType") } // ============================================================================= -// ROCm 6.x specific patches (HSA dmabuf) +// ROCm 6.x patches (HSA dmabuf workaround) // ============================================================================= +/// Patch rdmaxcel_hip.cpp for ROCm 6.x +/// Uses HSA hsa_amd_portable_export_dmabuf instead of hipMemGetHandleForAddressRange fn patch_rdmaxcel_cpp_rocm6(content: &str) -> String { let mut result = content .replace("#include ", @@ -240,166 +303,136 @@ fn patch_rdmaxcel_cpp_rocm6(content: &str) -> String { .replace("CUDAAllocatorConfig::", "HIPAllocatorConfig::") .replace("hipDeviceAttributePciDomainId", "hipDeviceAttributePciDomainID") .replace("static_cast", "reinterpret_cast") - .replace("static_cast", "reinterpret_cast") - .replace("CUDA_SUCCESS", "hipSuccess") - .replace("CUdevice device", "hipDevice_t device") - .replace("cuDeviceGet(&device", "hipDeviceGet(&device") - .replace("cuDeviceGetAttribute", "hipDeviceGetAttribute") - .replace("cuPointerGetAttribute", "hipPointerGetAttribute") - .replace("CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", "hipDeviceAttributePciBusId") - .replace("CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", "hipDeviceAttributePciDeviceId") - .replace("CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", "hipDeviceAttributePciDomainID") - .replace("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL") - .replace("CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD", "0 /* HSA dmabuf */"); - - // Convert cuMemGetHandleForAddressRange to HSA with correct argument order - result = result.replace("cuMemGetHandleForAddressRange(", "hsa_amd_portable_export_dmabuf("); - - // Fix argument order for HSA call (ptr, size, fd, flags) vs CUDA (fd, ptr, size, type, flags) - result = result.replace( - "hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(start_addr),\n total_size,\n 0 /* HSA dmabuf */,\n 0);", - "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(start_addr),\n total_size,\n &fd,\n nullptr);" - ); - result = result.replace( - "hsa_amd_portable_export_dmabuf(\n &fd,\n reinterpret_cast(chunk_start),\n chunk_size,\n 0 /* HSA dmabuf */,\n 0);", - "hsa_amd_portable_export_dmabuf(\n reinterpret_cast(chunk_start),\n chunk_size,\n &fd,\n nullptr);" - ); + .replace("static_cast", "reinterpret_cast"); + + // Apply constant and type replacements + result = apply_replacements(&result, CUDA_CONSTANT_REPLACEMENTS); + result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + // For ROCm 6.x, replace the dmabuf constant with HSA placeholder + result = result.replace("hipMemRangeHandleTypeDmaBufFd", "0 /* HSA dmabuf */"); + result - .replace("CUresult cu_result", "hsa_status_t hsa_result") - .replace("hipError_t cu_result", "hsa_status_t hsa_result") - .replace("cu_result != hipSuccess", "hsa_result != HSA_STATUS_SUCCESS") - .replace("if (cu_result", "if (hsa_result") - .replace("hipPointerAttribute::device", "HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL") } +/// Patch driver_api_hip.h for ROCm 6.x +/// Add HSA includes and fix types. Do NOT rename wrapper functions. fn patch_driver_api_h_rocm6(content: &str) -> String { - let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); - - // Add HSA includes if not present + let mut result = content.to_string(); + + // Add HSA includes if !result.contains("#include ") { result = result.replace( "#include ", "#include \n#include \n#include " ); } - - // Replace CUDA handle function declaration with HSA version - let old_decl = "hipError_t rdmaxcel_cuMemGetHandleForAddressRange(\n int* handle,\n hipDeviceptr_t dptr,\n size_t size,\n CUmemRangeHandleType handleType,\n unsigned long long flags);"; - let new_decl = "hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf(\n void* ptr,\n size_t size,\n int* fd,\n uint64_t* flags);"; - result = result.replace(old_decl, new_decl); - result = result.replace("CUmemRangeHandleType", "int /* ROCm 6.x placeholder */"); - - // Add CUDA-compatible wrapper declaration for existing callers - result.push_str(r#" - -// CUDA-compatible wrapper for code expecting cuMemGetHandleForAddressRange -hipError_t rdmaxcel_cuMemGetHandleForAddressRange( - int* handle, - hipDeviceptr_t dptr, - size_t size, - int handleType, - unsigned long long flags); -"#); + + // Apply type replacements only - do NOT rename wrapper function declarations + result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + result } +/// Patch driver_api_hip.cpp for ROCm 6.x +/// Converts internal HIP calls and replaces hipMemGetHandleForAddressRange with HSA fn patch_driver_api_cpp_rocm6(content: &str) -> String { - let mut result = apply_replacements(content, WRAPPER_REPLACEMENTS); - result = apply_replacements(&result, DRIVER_API_PTR_REPLACEMENTS); - result = apply_replacements(&result, DRIVER_API_MACRO_REPLACEMENTS); - + let mut result = content.to_string(); + // Add HSA includes result = result.replace( "#include \"driver_api_hip.h\"", "#include \"driver_api_hip.h\"\n#include \n#include " ); - + + // Fix library name and runtime result = result .replace("libcuda.so.1", "libamdhip64.so") - .replace("dstDevice, srcHost, ByteCount);", "dstDevice, const_cast(srcHost), ByteCount);") - .replace("->cuMemGetHandleForAddressRange_(", "->hipMemGetHandleForAddressRange_(") - .replace("_(cuMemGetHandleForAddressRange)", "_(hipMemGetHandleForAddressRange)"); - - // Replace the wrapper function with HSA version + .replace("cudaFree", "hipFree") + .replace("#include ", "#include "); + + // Apply macro entry replacements for dlsym lookups + result = apply_replacements(&result, MACRO_ENTRY_REPLACEMENTS); + + // Apply struct member access replacements + result = apply_replacements(&result, MEMBER_ACCESS_REPLACEMENTS); + + // Apply type replacements + result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + + // Fix const_cast for HtoD + result = result.replace( + "dstDevice, srcHost, ByteCount);", + "dstDevice, const_cast(srcHost), ByteCount);" + ); + + // For ROCm 6.x, hipMemGetHandleForAddressRange doesn't exist + // Remove it from the macro list and we'll add HSA wrapper separately + result = result.replace( + "_(hipMemGetHandleForAddressRange) \\", + "/* hipMemGetHandleForAddressRange not available in ROCm 6.x */ \\" + ); + result = result.replace( + "_(hipMemGetHandleForAddressRange) \\", + "/* hipMemGetHandleForAddressRange not available in ROCm 6.x */ \\" + ); + + // Replace the wrapper implementation to use HSA + // The wrapper function name stays as rdmaxcel_cuMemGetHandleForAddressRange let old_wrapper = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( int* handle, hipDeviceptr_t dptr, size_t size, - CUmemRangeHandleType handleType, + hipMemRangeHandleType handleType, unsigned long long flags) { return rdmaxcel::DriverAPI::get()->hipMemGetHandleForAddressRange_( handle, dptr, size, handleType, flags); }"#; - let hsa_impl = r#"hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( - void* ptr, + + let hsa_wrapper = r#"// ROCm 6.x: Use HSA hsa_amd_portable_export_dmabuf instead of hipMemGetHandleForAddressRange +hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, size_t size, - int* fd, - uint64_t* flags) { - // Direct HSA call - will be replaced with dlopen version - return hsa_amd_portable_export_dmabuf(ptr, size, fd, flags); + hipMemRangeHandleType handleType, + unsigned long long flags) { + (void)handleType; + (void)flags; + hsa_status_t status = hsa_amd_portable_export_dmabuf( + reinterpret_cast(dptr), size, handle, nullptr); + return (status == HSA_STATUS_SUCCESS) ? hipSuccess : hipErrorUnknown; }"#; - result = result.replace(old_wrapper, hsa_impl); - - // Handle alternate wrapper pattern + + result = result.replace(old_wrapper, hsa_wrapper); + + // Also handle if the type wasn't converted yet let old_wrapper2 = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( int* handle, hipDeviceptr_t dptr, size_t size, CUmemRangeHandleType handleType, unsigned long long flags) { - return rdmaxcel::DriverAPI::get()->cuMemGetHandleForAddressRange_( + return rdmaxcel::DriverAPI::get()->hipMemGetHandleForAddressRange_( handle, dptr, size, handleType, flags); }"#; - result = result.replace(old_wrapper2, hsa_impl); - - result = result.replace("CUmemRangeHandleType", "int /* ROCm 6.x placeholder */"); - - // Remove hipMemGetHandleForAddressRange from dlopen macro (we use HSA instead) - result = result.replace( - "_(hipMemGetHandleForAddressRange) \\", - "/* hipMemGetHandleForAddressRange - using HSA */ \\" - ); - result = result.replace( - "_(hipMemGetHandleForAddressRange) \\", - "/* hipMemGetHandleForAddressRange - using HSA */ \\" - ); + result = result.replace(old_wrapper2, hsa_wrapper); - // Add CUDA-compatible wrapper that translates to HSA - result.push_str(r#" - -// CUDA-compatible wrapper - translates cuMemGetHandleForAddressRange to HSA -hipError_t rdmaxcel_cuMemGetHandleForAddressRange( - int* handle, - hipDeviceptr_t dptr, - size_t size, - int handleType, - unsigned long long flags) { - (void)handleType; // ROCm 6.x only supports dmabuf - (void)flags; - hsa_status_t status = rdmaxcel_hsa_amd_portable_export_dmabuf( - reinterpret_cast(dptr), size, handle, nullptr); - return (status == HSA_STATUS_SUCCESS) ? hipSuccess : hipErrorUnknown; -} -"#); result } /// Apply dlopen patches to avoid link-time dependencies on HIP/HSA libraries. -/// -/// This is important because: -/// 1. hipFree must be called via dlopen'd pointer, not directly -/// 2. HSA functions must be loaded lazily to avoid libhsa-runtime64.so dependency fn patch_for_dlopen(content: &str) -> String { let mut result = content.to_string(); - // 1. Add hipFree to the dlopen macro list - result = result.replace( - "_(hipDrvGetErrorString)", - "_(hipDrvGetErrorString) \\\n _(hipFree)" - ); + // Add hipFree to dlopen macro list if not already there + if !result.contains("_(hipFree)") { + result = result.replace( + "_(hipDrvGetErrorString)", + "_(hipDrvGetErrorString) \\\n _(hipFree)" + ); + } - // 2. Reorder DriverAPI::get() to create singleton first, then call dlopen'd hipFree + // Reorder DriverAPI::get() to create singleton first, then call hipFree via dlopen result = result.replace( r#"DriverAPI* DriverAPI::get() { // Ensure we have a valid CUDA context for this thread @@ -415,48 +448,5 @@ fn patch_for_dlopen(content: &str) -> String { }"# ); - // 3. Replace direct HSA call with lazy dlopen version - result = result.replace( - r#"hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( - void* ptr, - size_t size, - int* fd, - uint64_t* flags) { - // Direct HSA call - will be replaced with dlopen version - return hsa_amd_portable_export_dmabuf(ptr, size, fd, flags); -}"#, - r#"// Lazy-loaded HSA function - avoids link-time dependency on libhsa-runtime64.so -static decltype(&hsa_amd_portable_export_dmabuf) g_hsa_export_dmabuf = nullptr; - -hsa_status_t rdmaxcel_hsa_amd_portable_export_dmabuf( - void* ptr, - size_t size, - int* fd, - uint64_t* flags) { - if (!g_hsa_export_dmabuf) { - // Try RTLD_NOLOAD first (library may already be loaded by HIP runtime) - void* handle = dlopen("libhsa-runtime64.so", RTLD_LAZY | RTLD_NOLOAD); - if (!handle) handle = dlopen("libhsa-runtime64.so", RTLD_LAZY); - if (!handle) { - throw std::runtime_error( - std::string("[RdmaXcel] Failed to load libhsa-runtime64.so: ") + dlerror()); - } - g_hsa_export_dmabuf = reinterpret_cast( - dlsym(handle, "hsa_amd_portable_export_dmabuf")); - if (!g_hsa_export_dmabuf) { - throw std::runtime_error( - std::string("[RdmaXcel] Symbol not found: hsa_amd_portable_export_dmabuf: ") + dlerror()); - } - } - return g_hsa_export_dmabuf(ptr, size, fd, flags); -}"# - ); - - // 4. Update the CUDA-compat wrapper to use our dlopen'd function - result = result.replace( - "hsa_status_t status = hsa_amd_portable_export_dmabuf(", - "hsa_status_t status = rdmaxcel_hsa_amd_portable_export_dmabuf(" - ); - result } diff --git a/rdmaxcel-sys/src/lib.rs b/rdmaxcel-sys/src/lib.rs index 21f194046..4aee2c99d 100644 --- a/rdmaxcel-sys/src/lib.rs +++ b/rdmaxcel-sys/src/lib.rs @@ -148,6 +148,10 @@ pub use inner::*; // ROCm/HIP Compatibility Aliases // ============================================================================= // These allow monarch_rdma to use CUDA names transparently on ROCm builds. +// +// IMPORTANT: The C++ wrapper functions keep their rdmaxcel_cu* names for API +// stability on both ROCm 6.x and ROCm 7+. Only the internal implementations +// differ (HSA vs native HIP). // --- Basic Type Aliases --- #[cfg(any(rocm_6_x, rocm_7_plus))] @@ -207,72 +211,75 @@ pub const CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD: i32 = 0; #[cfg(rocm_7_plus)] pub use inner::hipMemRangeHandleTypeDmaBufFd as CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD; +// ============================================================================= +// Driver API Wrapper Functions +// ============================================================================= +// The C++ exports rdmaxcel_cu* names for both ROCm 6.x and 7+. +// These are re-exported directly without renaming. + // --- Driver Init/Device Functions --- #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipInit as rdmaxcel_cuInit; +pub use inner::rdmaxcel_cuInit; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipDeviceGet as rdmaxcel_cuDeviceGet; +pub use inner::rdmaxcel_cuDeviceGet; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipDeviceGetCount as rdmaxcel_cuDeviceGetCount; +pub use inner::rdmaxcel_cuDeviceGetCount; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipPointerGetAttribute as rdmaxcel_cuPointerGetAttribute; +pub use inner::rdmaxcel_cuPointerGetAttribute; // --- Context Functions --- #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipCtxCreate as rdmaxcel_cuCtxCreate_v2; +pub use inner::rdmaxcel_cuCtxCreate_v2; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipCtxSetCurrent as rdmaxcel_cuCtxSetCurrent; +pub use inner::rdmaxcel_cuCtxSetCurrent; // --- Error Handling Functions --- #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipGetErrorString as rdmaxcel_cuGetErrorString; +pub use inner::rdmaxcel_cuGetErrorString; // --- Memory Management Functions --- #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemGetAllocationGranularity as rdmaxcel_cuMemGetAllocationGranularity; +pub use inner::rdmaxcel_cuMemGetAllocationGranularity; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemCreate as rdmaxcel_cuMemCreate; +pub use inner::rdmaxcel_cuMemCreate; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemAddressReserve as rdmaxcel_cuMemAddressReserve; +pub use inner::rdmaxcel_cuMemAddressReserve; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemMap as rdmaxcel_cuMemMap; +pub use inner::rdmaxcel_cuMemMap; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemSetAccess as rdmaxcel_cuMemSetAccess; +pub use inner::rdmaxcel_cuMemSetAccess; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemUnmap as rdmaxcel_cuMemUnmap; +pub use inner::rdmaxcel_cuMemUnmap; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemAddressFree as rdmaxcel_cuMemAddressFree; +pub use inner::rdmaxcel_cuMemAddressFree; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemRelease as rdmaxcel_cuMemRelease; +pub use inner::rdmaxcel_cuMemRelease; // --- Memory Copy/Set Functions --- #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemcpyHtoD as rdmaxcel_cuMemcpyHtoD_v2; +pub use inner::rdmaxcel_cuMemcpyHtoD_v2; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemcpyDtoH as rdmaxcel_cuMemcpyDtoH_v2; +pub use inner::rdmaxcel_cuMemcpyDtoH_v2; #[cfg(any(rocm_6_x, rocm_7_plus))] -pub use inner::rdmaxcel_hipMemsetD8 as rdmaxcel_cuMemsetD8_v2; +pub use inner::rdmaxcel_cuMemsetD8_v2; // --- Dmabuf Function --- -// ROCm 7+: direct alias to HIP function -#[cfg(rocm_7_plus)] -pub use inner::rdmaxcel_hipMemGetHandleForAddressRange as rdmaxcel_cuMemGetHandleForAddressRange; - -// ROCm 6.x: uses the CUDA-compatible wrapper we added in build.rs that internally calls HSA -#[cfg(rocm_6_x)] +// Both ROCm 6.x and 7+ export rdmaxcel_cuMemGetHandleForAddressRange from C++ +// (ROCm 6.x uses HSA internally, ROCm 7+ uses native hipMemGetHandleForAddressRange) +#[cfg(any(rocm_6_x, rocm_7_plus))] pub use inner::rdmaxcel_cuMemGetHandleForAddressRange; // ============================================================================= From 945ba678915b63669dc5a28a955edf3cbc11226e Mon Sep 17 00:00:00 2001 From: Zachary Streeter Date: Fri, 19 Dec 2025 20:38:33 +0000 Subject: [PATCH 12/12] tested on both rocm 6.0 and 7.0 and builds succesfull --- build_utils/src/rocm.rs | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/build_utils/src/rocm.rs b/build_utils/src/rocm.rs index 70d1ccd15..cb419d0f6 100644 --- a/build_utils/src/rocm.rs +++ b/build_utils/src/rocm.rs @@ -305,9 +305,15 @@ fn patch_rdmaxcel_cpp_rocm6(content: &str) -> String { .replace("static_cast", "reinterpret_cast") .replace("static_cast", "reinterpret_cast"); - // Apply constant and type replacements + // Apply constant replacements result = apply_replacements(&result, CUDA_CONSTANT_REPLACEMENTS); - result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + + // Apply type replacements - but NOT hipMemRangeHandleType (doesn't exist in ROCm 6.x) + result = result.replace("CUresult", "hipError_t"); + result = result.replace("CUdevice device", "hipDevice_t device"); + + // ROCm 6.x doesn't have hipMemRangeHandleType - use int as placeholder + result = result.replace("CUmemRangeHandleType", "int /* ROCm 6.x: no hipMemRangeHandleType */"); // For ROCm 6.x, replace the dmabuf constant with HSA placeholder result = result.replace("hipMemRangeHandleTypeDmaBufFd", "0 /* HSA dmabuf */"); @@ -328,8 +334,12 @@ fn patch_driver_api_h_rocm6(content: &str) -> String { ); } - // Apply type replacements only - do NOT rename wrapper function declarations - result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + // Apply type replacements - but NOT hipMemRangeHandleType (doesn't exist in ROCm 6.x) + result = result.replace("CUresult", "hipError_t"); + result = result.replace("CUdevice device", "hipDevice_t device"); + + // ROCm 6.x doesn't have hipMemRangeHandleType - use int as placeholder + result = result.replace("CUmemRangeHandleType", "int /* ROCm 6.x: no hipMemRangeHandleType */"); result } @@ -357,8 +367,12 @@ fn patch_driver_api_cpp_rocm6(content: &str) -> String { // Apply struct member access replacements result = apply_replacements(&result, MEMBER_ACCESS_REPLACEMENTS); - // Apply type replacements - result = apply_replacements(&result, CUDA_TYPE_REPLACEMENTS); + // Apply type replacements - but NOT hipMemRangeHandleType (doesn't exist in ROCm 6.x) + result = result.replace("CUresult", "hipError_t"); + result = result.replace("CUdevice device", "hipDevice_t device"); + + // ROCm 6.x doesn't have hipMemRangeHandleType - use int as placeholder + result = result.replace("CUmemRangeHandleType", "int /* ROCm 6.x: no hipMemRangeHandleType */"); // Fix const_cast for HtoD result = result.replace( @@ -394,7 +408,7 @@ hipError_t rdmaxcel_cuMemGetHandleForAddressRange( int* handle, hipDeviceptr_t dptr, size_t size, - hipMemRangeHandleType handleType, + int handleType, unsigned long long flags) { (void)handleType; (void)flags; @@ -416,6 +430,18 @@ hipError_t rdmaxcel_cuMemGetHandleForAddressRange( handle, dptr, size, handleType, flags); }"#; result = result.replace(old_wrapper2, hsa_wrapper); + + // Handle if the type was already replaced with int placeholder + let old_wrapper3 = r#"hipError_t rdmaxcel_cuMemGetHandleForAddressRange( + int* handle, + hipDeviceptr_t dptr, + size_t size, + int /* ROCm 6.x: no hipMemRangeHandleType */ handleType, + unsigned long long flags) { + return rdmaxcel::DriverAPI::get()->hipMemGetHandleForAddressRange_( + handle, dptr, size, handleType, flags); +}"#; + result = result.replace(old_wrapper3, hsa_wrapper); result }