From d11e053412960fc677d45215a0b3f8d9e23fc53c Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Tue, 3 Oct 2023 15:53:09 +0900 Subject: [PATCH] Add option to specify the EP to use, enabling DML EP and others (#17490) ### Description Add DML EP to the acceptable provider list in the optimizer. ### Motivation and Context With DML EP, graph optimization was not performed in onnxruntime. --- .../stable_diffusion/optimize_pipeline.py | 10 ++++ .../python/tools/transformers/optimizer.py | 52 ++++++++++++++++--- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 4512c971ac27c..aef60a534608a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -144,6 +144,7 @@ def _optimize_sd_pipeline( opt_level=0, optimization_options=fusion_options, use_gpu=True, + provider=args.provider, ) if float16: @@ -168,6 +169,7 @@ def _optimize_sd_pipeline( optimize_by_onnxruntime( str(tmp_model_path), use_gpu=True, + provider=args.provider, optimized_model_path=str(ort_optimized_model_path), save_as_external_data=use_external_data_format, ) @@ -324,6 +326,14 @@ def parse_arguments(argv: Optional[List[str]] = None): ) parser.set_defaults(use_external_data_format=None) + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use.", + ) + FusionOptions.add_arguments(parser) args = parser.parse_args(argv) diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 3f274eb6c835a..5ded027b36f74 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -69,6 +69,8 @@ def optimize_by_onnxruntime( save_as_external_data: bool = False, external_data_filename: str = "", external_data_file_threshold: int = 1024, + *, + provider: Optional[str] = None, ) -> str: """ Use onnxruntime to optimize model. @@ -82,6 +84,7 @@ def optimize_by_onnxruntime( save_as_external_data (bool): whether to save external data outside of ONNX model external_data_filename (str): name of external data file. If not provided, name is automatically created from ONNX model. external_data_file_threshold (int): threshold to decide whether to save tensor in ONNX model or in external data file + provider (str or None): execution provider to use if use_gpu Returns: optimized_model_path (str): the path of optimized model """ @@ -90,8 +93,12 @@ def optimize_by_onnxruntime( import onnxruntime - if use_gpu and set(onnxruntime.get_available_providers()).isdisjoint( - ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] + if ( + use_gpu + and provider is None + and set(onnxruntime.get_available_providers()).isdisjoint( + ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] + ) ): logger.error("There is no gpu for onnxruntime to do optimization.") return onnx_model_path @@ -138,17 +145,32 @@ def optimize_by_onnxruntime( kwargs["disabled_optimizers"] = disabled_optimizers if not use_gpu: - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=["CPUExecutionProvider"], **kwargs) + providers = ["CPUExecutionProvider"] + elif provider is not None: + if provider == "dml": + providers = ["DmlExecutionProvider"] + elif provider == "rocm": + providers = ["ROCMExecutionProvider"] + elif provider == "migraphx": + providers = ["MIGraphXExecutionProvider", "ROCMExecutionProvider"] + elif provider == "cuda": + providers = ["CUDAExecutionProvider"] + elif provider == "tensorrt": + providers = ["TensorrtExecutionProvider", "CUDAExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + + providers.append("CPUExecutionProvider") else: - gpu_ep = [] + providers = [] if torch_version.hip: - gpu_ep.append("MIGraphXExecutionProvider") - gpu_ep.append("ROCMExecutionProvider") + providers.append("MIGraphXExecutionProvider") + providers.append("ROCMExecutionProvider") else: - gpu_ep.append("CUDAExecutionProvider") + providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=gpu_ep, **kwargs) + onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -220,6 +242,8 @@ def optimize_model( use_gpu: bool = False, only_onnxruntime: bool = False, verbose: bool = False, + *, + provider: Optional[str] = None, ): """Optimize Model by OnnxRuntime and/or python fusion logic. @@ -257,6 +281,7 @@ def optimize_model( use_gpu (bool, optional): use gpu or not for onnxruntime. Defaults to False. only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion. Defaults to False. + provider (str, optional): execution provider to use if use_gpu. Defaults to None. Returns: object of an optimizer class. @@ -302,6 +327,7 @@ def optimize_model( temp_model_path = optimize_by_onnxruntime( input, use_gpu=use_gpu, + provider=provider, optimized_model_path=optimized_model_path, opt_level=opt_level, disabled_optimizers=disabled_optimizers, @@ -316,6 +342,7 @@ def optimize_model( temp_model_path = optimize_by_onnxruntime( input, use_gpu=use_gpu, + provider=provider, optimized_model_path=optimized_model_path, opt_level=1, disabled_optimizers=disabled_optimizers, @@ -423,6 +450,14 @@ def _parse_arguments(): ) parser.set_defaults(use_gpu=False) + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use if use_gpu", + ) + parser.add_argument( "--only_onnxruntime", required=False, @@ -501,6 +536,7 @@ def main(): opt_level=args.opt_level, optimization_options=optimization_options, use_gpu=args.use_gpu, + provider=args.provider, only_onnxruntime=args.only_onnxruntime, )