Skip to content

Commit

Permalink
Adding cuda 12.2 versions of sys
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Mar 13, 2024
1 parent ed3bb3e commit ae1bfbd
Show file tree
Hide file tree
Showing 41 changed files with 25,514 additions and 72 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ keywords = [
features = ["ci-check", "f16", "cudnn"]

[features]
default = ["std", "driver", "nvrtc", "cublas", "curand"]
default = ["std", "driver", "nvrtc", "cublas", "curand", "cuda_version_11_8"]
cuda_version_11_8 = []
cuda_version_12_2 = []
nvrtc = []
driver = ["nvrtc"]
cublas = ["driver"]
Expand Down
16 changes: 0 additions & 16 deletions src/cublas/bindgen.sh

This file was deleted.

19 changes: 19 additions & 0 deletions src/cublas/sys/bindgen.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash
set -exu

bindgen \
--allowlist-var="^CUDA_VERSION" \
--allowlist-type="^cublas.*" \
--allowlist-function="^cublas.*" \
--default-enum-style=rust \
--no-doc-comments \
--with-derive-default \
--with-derive-eq \
--with-derive-hash \
--with-derive-ord \
--use-core \
wrapper.h -- -I/usr/local/cuda/include \
> tmp.rs

CUDA_VERSION=$(cat tmp.rs | grep "CUDA_VERSION" | awk '{ print $6 }' | sed 's/.$//')
mv tmp.rs sys_${CUDA_VERSION}.rs
9 changes: 9 additions & 0 deletions src/cublas/sys/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#[cfg(feature = "cuda_version_11_8")]
mod sys_11080;
#[cfg(feature = "cuda_version_11_8")]
pub use sys_11080::*;

#[cfg(feature = "cuda_version_12_2")]
mod sys_12020;
#[cfg(feature = "cuda_version_12_2")]
pub use sys_12020::*;
File renamed without changes.
Loading

0 comments on commit ae1bfbd

Please sign in to comment.