Skip to content

Commit 0a17489

Browse files
apwojcikjeffdaily
authored andcommitted
Bundle MIGraphX with ROCm when built together (#47)
* create package for migraphx ep * add migrahx to the gpu providers for benchmark.py * remove rocm from migraphx perfs tests
1 parent 4004369 commit 0a17489

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

onnxruntime/python/tools/transformers/benchmark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def run_onnxruntime(
117117
if (
118118
use_gpu
119119
and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers())
120+
and ("MIGraphXExecutionProvider" not in onnxruntime.get_available_providers())
120121
and ("ROCMExecutionProvider" not in onnxruntime.get_available_providers())
121122
and ("DmlExecutionProvider" not in onnxruntime.get_available_providers())
122123
):

onnxruntime/test/perftest/ort_test_session.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,6 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
536536
} else if (provider_name_ == onnxruntime::kMIGraphXExecutionProvider) {
537537
#ifdef USE_MIGRAPHX
538538
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(session_options, 0));
539-
OrtROCMProviderOptions rocm_options;
540-
rocm_options.miopen_conv_exhaustive_search = performance_test_config.run_config.cudnn_conv_algo;
541-
rocm_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream;
542-
session_options.AppendExecutionProvider_ROCM(rocm_options);
543539
#else
544540
ORT_THROW("MIGraphX is not supported in this build\n");
545541
#endif

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def parse_arg_remove_string(argv, arg_name_equal):
6666
elif parse_arg_remove_boolean(sys.argv, "--use_rocm"):
6767
is_rocm = True
6868
rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=")
69+
if parse_arg_remove_boolean(sys.argv, "--use_migraphx"):
70+
is_migraphx = True
6971
elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"):
7072
is_migraphx = True
7173
elif parse_arg_remove_boolean(sys.argv, "--use_openvino"):
@@ -89,8 +91,10 @@ def parse_arg_remove_string(argv, arg_name_equal):
8991
elif parse_arg_remove_boolean(sys.argv, "--use_qnn"):
9092
package_name = "onnxruntime-qnn"
9193

92-
if is_rocm or is_migraphx:
94+
if is_rocm:
9395
package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly"
96+
elif is_migraphx:
97+
package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly"
9498

9599
# PEP 513 defined manylinux1_x86_64 and manylinux1_i686
96100
# PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686

tools/ci_build/build.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2264,6 +2264,8 @@ def build_python_wheel(
22642264
args.append("--use_rocm")
22652265
if rocm_version:
22662266
args.append(f"--rocm_version={rocm_version}")
2267+
if use_migraphx:
2268+
args.append("--use_migraphx")
22672269
elif use_migraphx:
22682270
args.append("--use_migraphx")
22692271
elif use_openvino:

0 commit comments

Comments
 (0)