Skip to content

Commit

Permalink
fix(sys): more robust CUDA version check. fixes #234
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jul 7, 2024
1 parent 3dec017 commit bb57252
Showing 1 changed file with 40 additions and 29 deletions.
69 changes: 40 additions & 29 deletions ort-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use std::{
env, fs,
path::{Path, PathBuf}
path::{Path, PathBuf},
process::Command
};

#[allow(unused)]
const ONNXRUNTIME_VERSION: &str = "1.18.1";

const ORT_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_LIB_LOCATION";
const ORT_ENV_SYSTEM_LIB_PROFILE: &str = "ORT_LIB_PROFILE";
#[cfg(feature = "download-binaries")]
Expand Down Expand Up @@ -46,29 +50,6 @@ fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static
.map(|c| (c[2], c[3]))
}

fn lib_exists(name: &str) -> bool {
#[cfg(any(target_family = "windows", unix))]
let lib_str = std::ffi::CString::new(name).unwrap();
// note that we're not performing any cleanup here because this is a short lived build script; the OS will clean it up
// for us when we finish
#[cfg(target_family = "windows")]
return unsafe {
extern "C" {
fn LoadLibraryA(lplibfilename: *const std::ffi::c_char) -> isize;
}
LoadLibraryA(lib_str.as_ptr()) != 0
};
#[cfg(unix)]
return unsafe {
extern "C" {
fn dlopen(file: *const std::ffi::c_char, mode: std::ffi::c_int) -> *const std::ffi::c_void;
}
!dlopen(lib_str.as_ptr(), 1).is_null()
};
#[cfg(not(any(target_family = "windows", unix)))]
return false;
}

#[cfg(feature = "download-binaries")]
fn hex_str_to_bytes(c: impl AsRef<[u8]>) -> Vec<u8> {
fn nibble(c: u8) -> u8 {
Expand Down Expand Up @@ -136,7 +117,7 @@ fn copy_libraries(lib_dir: &Path, out_dir: &Path) {
#[cfg(target_os = "linux")]
{
let main_dy = lib_dir.join("libonnxruntime.so");
let versioned_dy = out_dir.join("libonnxruntime.so.1.17.3");
let versioned_dy = out_dir.join(format!("libonnxruntime.so.{}", ONNXRUNTIME_VERSION));
if main_dy.exists() && !versioned_dy.exists() {
if versioned_dy.is_symlink() {
fs::remove_file(&versioned_dy).unwrap();
Expand Down Expand Up @@ -347,15 +328,38 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
feature_set.push("train");
}
if cfg!(any(feature = "cuda", feature = "tensorrt")) {
if lib_exists("cudart64_11.dll") || lib_exists("libcudart.so.11") || env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() == Ok("11") {
feature_set.push("cu11");
} else {
feature_set.push("cu12");
match env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() {
Ok("11") => feature_set.push("cu11"),
Ok("12") => feature_set.push("cu12"),
_ => {
let mut success = false;
if let Ok(nvcc_output) = Command::new("nvcc").arg("--version").output() {
if nvcc_output.status.success() {
let stdout = String::from_utf8_lossy(&nvcc_output.stdout);
let version_line = stdout.lines().nth(3).unwrap();
let release_section = version_line.split(", ").nth(1).unwrap();
let version_number = release_section.split(' ').nth(1).unwrap();
if version_number.starts_with("12") {
feature_set.push("cu12");
} else {
feature_set.push("cu11");
}
success = true;
}
}

if !success {
println!("cargo:warning=nvcc call did not succeed. falling back to CUDA 12");
// fallback to CUDA 12.
feature_set.push("cu12");
}
}
}
} else if cfg!(feature = "rocm") {
feature_set.push("rocm");
}
let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() };
println!("selected feature set: {feature_set}");
let mut dist = find_dist(&target, &feature_set);
if dist.is_none() && feature_set != "none" {
dist = find_dist(&target, "none");
Expand Down Expand Up @@ -411,6 +415,13 @@ fn prepare_libort_dir() -> (PathBuf, bool) {
fn try_setup_with_pkg_config() -> bool {
match pkg_config::Config::new().probe("libonnxruntime") {
Ok(lib) => {
let expected_minor = ONNXRUNTIME_VERSION.split('.').nth(1).unwrap().parse::<usize>().unwrap();
let got_minor = lib.version.split('.').nth(1).unwrap().parse::<usize>().unwrap();
if got_minor < expected_minor {
println!("libonnxruntime provided by pkg-config is out of date, so it will be ignored - expected {}, got {}", ONNXRUNTIME_VERSION, lib.version);
return false;
}

// Setting the link paths
for path in lib.link_paths {
println!("cargo:rustc-link-search=native={}", path.display());
Expand Down

0 comments on commit bb57252

Please sign in to comment.