Skip to content

Commit

Permalink
[CUDA] Preload dependent DLLs (#23674)
Browse files Browse the repository at this point in the history
### Description

Changes:
(1) Pass --cuda_version in packaging pipeline to build wheel command
line so that cuda_version can be saved. Note that cuda_version is also
required for generating extra_require for
#23659.
(2) Update steup.py and onnxruntime_validation.py to save cuda version
to capi/build_and_package_info.py.
(3) Add a helper function to preload dependent DLLs (MSVC, CUDA, CUDNN)
in `__init__.py`. First we will try to load DLLs from nvidia site
packages, then try load remaining DLLs with default path settings.

```
import onnxruntime
onnxruntime.preload_dlls()
```

To show loaded DLLs, set `verbose=True`. It is also possible to disable
loading some types of DLLs like:
```
onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
```

#### PyTorch and onnxruntime in Windows

When working with pytorch, onnxruntime will reuse the CUDA and cuDNN
DLLs loaded by pytorch as long as CUDA and cuDNN major versions are
compatible. Preload DLLs actually might cause issues (see example 2 and
3 below) in Windows.

Example 1: onnxruntime and torch can work together easily. 
```
>>> import torch
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\curand64_10.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc-builtins64_124.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_cnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufftw64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\caffe2_nvrtc.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
>>> session.get_providers()
['CUDAExecutionProvider', 'CPUExecutionProvider']
```

Example 2: Use preload_dlls after `import torch` is not necessary.
Unfortunately, it seems that multiple DLLs of same filename are loaded.
They can be used in parallel but not ideal since more memory is used.
```
>>> import torch
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\curand64_10.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc-builtins64_124.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_cnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufftw64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\caffe2_nvrtc.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
```

Example 3: Use preload_dlls before `import torch` might cause torch
import error in Windows. Later we may provide an option to load DLLs
from torch directory to avoid this issue.
```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
>>> import torch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\anaconda3\envs\py310\lib\site-packages\torch\__init__.py", line 137, in <module>
    raise err
OSError: [WinError 127] The specified procedure could not be found. Error loading "D:\anaconda3\envs\py310\lib\site-packages\torch\lib\cudnn_adv64_9.dll" or one of its dependencies.
```

#### PyTorch and onnxruntime in Linux

In Linux, since pytorch uses nvidia site packages for CUDA and cuDNN
DLLs. Preload DLLs consistently loads same set of DLLs, and it could
help maintaining.

```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_graph.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12
>>> import torch
>>> torch.rand(3, 3).cuda()
tensor([[0.4619, 0.0279, 0.2092],
        [0.0416, 0.6782, 0.5889],
        [0.9988, 0.9092, 0.7982]], device='cuda:0')
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> session.get_providers()
['CUDAExecutionProvider', 'CPUExecutionProvider']
```

```
>>> import torch
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12
```

Without preloading DLLs, onnxruntime will load CUDA and cuDNN DLLs based
on `LD_LIBRARY_PATH`. Torch will reuse the same DLLs loaded by
onnxruntime:
```
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
>>> import torch
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
>>> torch.rand(3, 3).cuda()
tensor([[0.2233, 0.9194, 0.8078],
        [0.0906, 0.2884, 0.3655],
        [0.6249, 0.2904, 0.4568]], device='cuda:0')
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
```

### Motivation and Context
In many reported issues of import onnxruntime failure, the root cause is
dependent DLLs missing or not in path. This change will make it easier
to resolve those issues.

This is based on Jian's PR
#22506 with extra change to
load msvc dlls.

#23659 can be used to
install CUDA/cuDNN dlls to site packages. Example command line after
next official release 1.21:
```
pip install onnxruntime-gpu[cuda,cudnn]
```

If user installed pytorch in Linux, those DLLs are usually installed
together with torch.
  • Loading branch information
tianleiwu authored Feb 15, 2025
1 parent 4f66610 commit c7aa9a7
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 63 deletions.
96 changes: 95 additions & 1 deletion onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,103 @@
except ImportError:
pass

from onnxruntime.capi.onnxruntime_validation import cuda_version, package_name, version # noqa: F401

package_name, version, cuda_version = onnxruntime_validation.get_package_name_and_version_info()

if version:
__version__ = version

onnxruntime_validation.check_distro_info()


def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, verbose: bool = False):
import ctypes
import os
import platform
import site

if platform.system() not in ["Windows", "Linux"]:
return

is_windows = platform.system() == "Windows"
if is_windows and msvc:
try:
ctypes.CDLL("vcruntime140.dll")
ctypes.CDLL("msvcp140.dll")
if platform.machine() != "ARM64":
ctypes.CDLL("vcruntime140_1.dll")
except OSError:
print("Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.")
print("It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe.")

if cuda_version and cuda_version.startswith("12.") and (cuda or cudnn):
# Paths are relative to nvidia root in site packages.
if is_windows:
cuda_dll_paths = [
("cublas", "bin", "cublasLt64_12.dll"),
("cublas", "bin", "cublas64_12.dll"),
("cufft", "bin", "cufft64_11.dll"),
("cuda_runtime", "bin", "cudart64_12.dll"),
]
cudnn_dll_paths = [
("cudnn", "bin", "cudnn_graph64_9.dll"),
("cudnn", "bin", "cudnn64_9.dll"),
]
else: # Linux
# cublas64 depends on cublasLt64, so cublasLt64 should be loaded first.
cuda_dll_paths = [
("cublas", "lib", "libcublasLt.so.12"),
("cublas", "lib", "libcublas.so.12"),
("cuda_nvrtc", "lib", "libnvrtc.so.12"),
("curand", "lib", "libcurand.so.10"),
("cufft", "lib", "libcufft.so.11"),
("cuda_runtime", "lib", "libcudart.so.12"),
]
cudnn_dll_paths = [
("cudnn", "lib", "libcudnn_graph.so.9"),
("cudnn", "lib", "libcudnn.so.9"),
]

# Try load DLLs from nvidia site packages.
dll_paths = (cuda_dll_paths if cuda else []) + (cudnn_dll_paths if cudnn else [])
loaded_dlls = []
for site_packages_path in reversed(site.getsitepackages()):
nvidia_path = os.path.join(site_packages_path, "nvidia")
if os.path.isdir(nvidia_path):
for relative_path in dll_paths:
dll_path = os.path.join(nvidia_path, *relative_path)
if os.path.isfile(dll_path):
try:
_ = ctypes.CDLL(dll_path)
loaded_dlls.append(relative_path[-1])
except Exception as e:
print(f"Failed to load {dll_path}: {e}")
break

# Try load DLLs with default path settings.
has_failure = False
for relative_path in dll_paths:
dll_filename = relative_path[-1]
if dll_filename not in loaded_dlls:
try:
_ = ctypes.CDLL(dll_filename)
except Exception as e:
has_failure = True
print(f"Failed to load {dll_filename}: {e}")

if has_failure:
print("Please follow https://onnxruntime.ai/docs/install/#cuda-and-cudnn to install CUDA and CuDNN.")

if verbose:

def is_target_dll(path: str):
target_keywords = ["cufft", "cublas", "cudart", "nvrtc", "curand", "cudnn", "vcruntime140", "msvcp140"]
return any(keyword in path for keyword in target_keywords)

import psutil

p = psutil.Process(os.getpid())
print("----List of loaded DLLs----")
for lib in p.memory_maps():
if is_target_dll(lib.path.lower()):
print(lib.path)
91 changes: 47 additions & 44 deletions onnxruntime/python/onnxruntime_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,28 @@ def check_distro_info():
)


def validate_build_package_info():
def get_package_name_and_version_info():
package_name = ""
version = ""
cuda_version = ""

try:
from .build_and_package_info import __version__ as version
from .build_and_package_info import package_name

try: # noqa: SIM105
from .build_and_package_info import cuda_version
except ImportError:
# cuda_version is optional. For example, cpu only package does not have the attribute.
pass
except Exception as e:
warnings.warn("WARNING: failed to collect package name and version info")
print(e)

return package_name, version, cuda_version


def check_training_module():
import_ortmodule_exception = None

has_ortmodule = False
Expand Down Expand Up @@ -96,48 +117,33 @@ def validate_build_package_info():
if not has_ortmodule:
import_ortmodule_exception = e

package_name = ""
version = ""
cuda_version = ""
# collect onnxruntime package name, version, and cuda version
package_name, version, cuda_version = get_package_name_and_version_info()

if has_ortmodule:
if has_ortmodule and cuda_version:
try:
# collect onnxruntime package name, version, and cuda version
from .build_and_package_info import __version__ as version
from .build_and_package_info import package_name

try: # noqa: SIM105
from .build_and_package_info import cuda_version
except Exception:
pass

if cuda_version:
# collect cuda library build info. the library info may not be available
# when the build environment has none or multiple libraries installed
try:
from .build_and_package_info import cudart_version
except Exception:
warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
cudart_version = None

def print_build_package_info():
warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
warnings.warn(f"onnxruntime training package info: __version__: {version}")
warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")

# collection cuda library info from current environment.
from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions

local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
print_build_package_info()
warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
else:
# TODO: rcom
pass

# collect cuda library build info. the library info may not be available
# when the build environment has none or multiple libraries installed
try:
from .build_and_package_info import cudart_version
except ImportError:
warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
cudart_version = None

def print_build_package_info():
warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
warnings.warn(f"onnxruntime training package info: __version__: {version}")
warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")

# collection cuda library info from current environment.
from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions

local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
print_build_package_info()
warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
except Exception as e:
warnings.warn("WARNING: failed to collect onnxruntime version and build info")
print(e)
Expand All @@ -146,6 +152,3 @@ def print_build_package_info():
raise import_ortmodule_exception

return has_ortmodule, package_name, version, cuda_version


has_ortmodule, package_name, version, cuda_version = validate_build_package_info()
31 changes: 16 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,21 +724,22 @@ def reformat_run_count(count_str):
with open(requirements_path) as f:
install_requires = f.read().splitlines()

if enable_training:

def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
from onnxruntime_collect_build_info import find_cudart_versions
def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
from onnxruntime_collect_build_info import find_cudart_versions

version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
with open(version_path, "w") as f:
f.write(f"package_name = '{package_name}'\n")
f.write(f"__version__ = '{version_number}'\n")
version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
with open(version_path, "w") as f:
f.write(f"package_name = '{package_name}'\n")
f.write(f"__version__ = '{version_number}'\n")

if cuda_version:
f.write(f"cuda_version = '{cuda_version}'\n")
if cuda_version:
f.write(f"cuda_version = '{cuda_version}'\n")

# cudart_versions are integers
# The cudart version used in building training packages in Linux.
# It is possible to parse version.json at cuda_home in build.py, then pass in the parameter directly.
if enable_training and platform.system().lower() == "linux":
cudart_versions = find_cudart_versions(build_env=True)
if cudart_versions and len(cudart_versions) == 1:
f.write(f"cudart_version = {cudart_versions[0]}\n")
Expand All @@ -751,10 +752,11 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
else "found multiple cudart libraries"
),
)
elif rocm_version:
f.write(f"rocm_version = '{rocm_version}'\n")
elif rocm_version:
f.write(f"rocm_version = '{rocm_version}'\n")

save_build_and_package_info(package_name, version_number, cuda_version, rocm_version)

save_build_and_package_info(package_name, version_number, cuda_version, rocm_version)

extras_require = {}
if package_name == "onnxruntime-gpu" and is_cuda_version_12:
Expand All @@ -770,7 +772,6 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
],
}

# Setup
setup(
name=package_name,
version=version_number,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ parameters:
- Release
- RelWithDebInfo
- MinSizeRel

- name: use_tensorrt
type: boolean
default: false
Expand Down Expand Up @@ -141,7 +141,7 @@ stages:
displayName: 'Build wheel'
inputs:
scriptPath: '$(Build.SourcesDirectory)\setup.py'
arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=${{ parameters.EP_NAME }}'
arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=${{ parameters.EP_NAME }} --cuda_version=${{ parameters.CudaVersion }}'
workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}'

- task: CopyFiles@2
Expand Down Expand Up @@ -195,7 +195,7 @@ stages:
TMPDIR: "$(Agent.TempDirectory)"

- powershell: |
python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq
Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate}
mkdir -p $(Agent.TempDirectory)\ort_test_data
Expand Down

0 comments on commit c7aa9a7

Please sign in to comment.