From 67e0e69543432ada4cac970459df86f0ac12eac9 Mon Sep 17 00:00:00 2001 From: "Micah Chambers (minerva)" Date: Thu, 10 Apr 2025 16:11:04 -0700 Subject: [PATCH] add option to link to different tensorrt --- build.rs | 70 +++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/build.rs b/build.rs index 7206204..e1db1cb 100644 --- a/build.rs +++ b/build.rs @@ -1,37 +1,75 @@ -fn main() { - let cuda_path = std::env::var("CUDA_PATH").map(std::path::PathBuf::from); +fn search_for_path( + base_env: &str, + default_base: Option<&str>, + include_env: &str, + lib_env: &str, +) -> (std::path::PathBuf, std::path::PathBuf) { + let base_path = if let Some(default_base) = default_base { + std::env::var(base_env) + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| std::path::PathBuf::from(default_base)) + } else { + std::env::var(base_env) + .map(std::path::PathBuf::from) + .expect(&format!("Missing environment variable `{base_env}`.")) + }; - #[cfg(not(windows))] - let cuda_path = cuda_path.unwrap_or_else(|_| std::path::PathBuf::from("/usr/local/cuda")); - #[cfg(windows)] - let cuda_path = cuda_path.expect("Missing environment variable `CUDA_PATH`."); - - let cuda_include_path = std::env::var("CUDA_INCLUDE_PATH") + let include_path = std::env::var(include_env) .map(std::path::PathBuf::from) - .unwrap_or_else(|_| cuda_path.join("include")); + .unwrap_or_else(|_| base_path.join("include")); - let cuda_lib_path = std::env::var("CUDA_LIB_PATH") + let lib_path = std::env::var(lib_env) .map(std::path::PathBuf::from) .unwrap_or_else(|_| { #[cfg(not(windows))] { - cuda_path.join("lib64") + base_path.join("lib64") } #[cfg(windows)] { - cuda_path.join("lib").join("x64") + base_path.join("lib").join("x64") } }); + (include_path, lib_path) +} + +fn main() { + #[cfg(not(windows))] + let (cuda_include_path, cuda_lib_path) = search_for_path( + "CUDA_PATH", + Some("/usr/local/cuda"), + "CUDA_INCLUDE_PATH", + "CUDA_LIB_PATH", + ); + + #[cfg(windows)] + let (cuda_include_path, cuda_lib_path) = + search_for_path("CUDA_PATH", None, "CUDA_INCLUDE_PATH", "CUDA_LIB_PATH"); + + #[cfg(not(windows))] + let (tensorrt_include_path, tensorrt_lib_path) = search_for_path( + "TENSORRT_PATH", + Some("/usr/local/tensorrt"), + "TENSORRT_INCLUDE_PATH", + "TENSORRT_LIB_PATH", + ); + + #[cfg(windows)] + let (tensorrt_include_path, tensorrt_lib_path) = search_for_path( + "TENSORRT_PATH", + None, + "TENSORRT_INCLUDE_PATH", + "TENSORRT_LIB_PATH", + ); + let mut cpp_build_config = cpp_build::Config::new(); cpp_build_config.include(cuda_include_path); - #[cfg(not(windows))] - cpp_build_config.include("/usr/local/tensorrt/include"); + cpp_build_config.include(tensorrt_include_path); cpp_build_config.build("src/lib.rs"); println!("cargo:rustc-link-search={}", cuda_lib_path.display()); - #[cfg(not(windows))] - println!("cargo:rustc-link-search=/usr/local/tensorrt/lib64"); + println!("cargo:rustc-link-search={}", tensorrt_lib_path.display()); println!("cargo:rustc-link-lib=nvinfer"); println!("cargo:rustc-link-lib=nvonnxparser");