Skip to content

Commit

Permalink
#260 Add lib paths for cudnn dynamic linking. Lib names for dynamic l…
Browse files Browse the repository at this point in the history
…oading
  • Loading branch information
coreylowman committed Jun 26, 2024
1 parent 4ecd0b6 commit e4a6d16
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
54 changes: 34 additions & 20 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn main() {
println!("cargo:rustc-env=CUDA_MINOR_VERSION={minor}");

#[cfg(feature = "dynamic-linking")]
dynamic_linking();
dynamic_linking(major, minor);
}

#[allow(unused)]
Expand Down Expand Up @@ -81,10 +81,11 @@ fn cuda_version_from_build_system() -> (usize, usize) {
}

#[allow(unused)]
fn dynamic_linking() {
fn dynamic_linking(major: usize, minor: usize) {
let candidates: Vec<PathBuf> = root_candidates().collect();

let toolkit_root = root_candidates()
let toolkit_root = candidates
.iter()
.find(|path| path.join("include").join("cuda.h").is_file())
.unwrap_or_else(|| {
panic!(
Expand All @@ -93,22 +94,26 @@ fn dynamic_linking() {
)
});

for path in lib_candidates(&toolkit_root) {
for path in lib_candidates(&toolkit_root, major, minor) {

Check failure on line 97 in build.rs

View workflow job for this annotation

GitHub Actions / clippy

this expression creates a reference which is immediately dereferenced by the compiler
println!("cargo:rustc-link-search=native={}", path.display());
}

#[cfg(feature = "cudnn")]
{
let cudnn_root = root_candidates()
.find(|path| path.join("include").join("cudnn.h").is_file())
let cudnn_root = candidates
.iter()
.find(|path| {
path.join("include").join("cudnn.h").is_file()
|| path.join("include").join(std::format!("{major}.{minor}")).join("cudnn.h").is_file()
})
.unwrap_or_else(|| {
panic!(
"Unable to find `include/cudnn.h` under any of: {:?}. Set the `CUDNN_LIB` environment variable to `$CUDNN_LIB/include/cudnn.h` to override path.",
"Unable to find `include/cudnn.h` or `include/{major}.{minor}/cudnn.h` under any of: {:?}. Set the `CUDNN_LIB` environment variable to override path, or turn off dynamic linking (to enable dynamic loading).",
candidates
)
});

for path in lib_candidates(&cudnn_root) {
for path in lib_candidates(&cudnn_root, major, minor) {

Check failure on line 116 in build.rs

View workflow job for this annotation

GitHub Actions / clippy

this expression creates a reference which is immediately dereferenced by the compiler
println!("cargo:rustc-link-search=native={}", path.display());
}
}
Expand Down Expand Up @@ -148,28 +153,37 @@ fn root_candidates() -> impl Iterator<Item = PathBuf> {
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/Program Files/NVIDIA",
"C:/CUDA",
// See issue #260
"C:/Program Files/NVIDIA/CUDNN/v9.2",
"C:/Program Files/NVIDIA/CUDNN/v9.1",
"C:/Program Files/NVIDIA/CUDNN/v9.0",
];
let roots = roots.into_iter().map(Into::into);
env_vars.chain(roots).map(Into::<PathBuf>::into)
}

#[allow(unused)]
fn lib_candidates(root: &Path) -> Vec<PathBuf> {
fn lib_candidates(root: &Path, major: usize, minor: usize) -> Vec<PathBuf> {
[
"lib",
"lib/x64",
"lib/Win32",
"lib/x86_64",
"lib/x86_64-linux-gnu",
"lib64",
"lib64/stubs",
"targets/x86_64-linux",
"targets/x86_64-linux/lib",
"targets/x86_64-linux/lib/stubs",
"lib".into(),
"lib/x64".into(),
"lib/Win32".into(),
"lib/x86_64".into(),
"lib/x86_64-linux-gnu".into(),
"lib64".into(),
"lib64/stubs".into(),
"targets/x86_64-linux".into(),
"targets/x86_64-linux/lib".into(),
"targets/x86_64-linux/lib/stubs".into(),
// see issue #260
std::format!("lib/{major}.{minor}/x64"),
// see issue #260
std::format!("lib/{major}.{minor}/x86_64"),
]
.iter()
.map(|&p| root.join(p))
.map(|p| root.join(p))
.filter(|p| p.is_dir())
.collect()
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ pub(crate) fn get_lib_name_candidates(lib_name: &str) -> std::vec::Vec<std::stri
std::format!("{lib_name}{pointer_width}_10"),
// See issue #246
std::format!("{lib_name}{pointer_width}_{major}0_0"),
// See issue #260
std::format!("{lib_name}{pointer_width}_9"),
]
.into()
}

0 comments on commit e4a6d16

Please sign in to comment.