From b1f3135d888ce9153866fe13aa053095a9878ccc Mon Sep 17 00:00:00 2001
From: Jeff Daily
Date: Tue, 3 Oct 2023 11:29:17 -0700
Subject: [PATCH] CUDA EP vs ROCM EP hipify audit

Migrate all CUDA EP improvements and changes to the ROCM EP. The process
involves running hipify against all CUDA EP files (i.e. excluding no files
in onnxruntime_rocm_hipify.cmake), then comparing the results with vimdiff
against the ROCM EP files that are under source control and pulling in most
changes (a sketch of this audit loop appears after the patch body below).
These changes are functional as well as formatting; they make comparing the
CUDA EP and ROCM EP easier, though the formatting churn makes the PR diff
itself somewhat harder to read.

- hipify audit of onnxruntime/core/providers/rocm, enable ops
  - Loop
  - Scan
- hipify audit of onnxruntime/contrib_ops/rocm
  - fix contrib ops search implementation
  - enable more contrib ops
    - Affine
    - ComplexMul
    - ConvTransposeWithDynamicPads
    - Crop
    - DynamicSlice
    - FFT [Rfft, Irfft]
    - GreedySearch
    - ImageScaler
    - ParametricSoftplus
    - ScaledTanh
    - ThresholdRelu
---
 cmake/onnxruntime_providers.cmake | 3 +-
 cmake/onnxruntime_rocm_hipify.cmake | 27 -
 .../cpu/transformers/beam_search.cc | 12 +-
 .../cpu/transformers/beam_search.h | 6 +-
 .../cpu/transformers/beam_search_impl_gpt.h | 6 +-
 .../cpu/transformers/beam_search_impl_t5.h | 6 +-
 .../transformers/beam_search_impl_whisper.h | 6 +-
 .../transformers/generation_device_helper.h | 2 +-
 .../cpu/transformers/greedy_search.cc | 4 +-
 .../cpu/transformers/greedy_search.h | 6 +-
 .../cpu/transformers/greedy_search_impl_gpt.h | 6 +-
 .../contrib_ops/cpu/transformers/sampling.cc | 4 +-
 .../contrib_ops/cpu/transformers/sampling.h | 6 +-
 .../contrib_ops/rocm/rocm_contrib_kernels.cc | 144 +-
 .../core/providers/rocm/cu_inc/common.cuh | 20 +-
 onnxruntime/core/providers/rocm/fpgeneric.cu | 4 +-
 .../core/providers/rocm/gpu_data_transfer.cc | 34 +-
 .../core/providers/rocm/gpu_data_transfer.h | 4 +-
 .../core/providers/rocm/integer_gemm.cc | 21 +-
 onnxruntime/core/providers/rocm/math/einsum.h | 5 +-
 .../math/einsum_utils/einsum_auxiliary_ops.h | 13 +-
 .../core/providers/rocm/math/softmax.cc | 27 +-
 .../core/providers/rocm/math/softmax.h | 14 +-
 onnxruntime/core/providers/rocm/nn/conv.cc | 44 +-
 onnxruntime/core/providers/rocm/nn/conv.h | 11 +-
 .../core/providers/rocm/nn/conv_transpose.cc | 3 +-
 .../providers/rocm/reduction/reduction_ops.cc | 65 +-
 .../core/providers/rocm/rocm_allocator.cc | 5 +-
 .../core/providers/rocm/rocm_allocator.h | 3 +-
 onnxruntime/core/providers/rocm/rocm_call.cc | 2 +-
 .../providers/rocm/rocm_execution_provider.cc | 2077 ++++++++---------
 .../providers/rocm/rocm_execution_provider.h | 25 +-
 .../rocm/rocm_execution_provider_info.cc | 6 +-
 onnxruntime/core/providers/rocm/rocm_fwd.h | 13 -
 onnxruntime/core/providers/rocm/rocm_kernel.h | 58 +-
 .../providers/rocm/rocm_provider_factory.cc | 71 +-
 .../providers/rocm/rocm_provider_factory.h | 12 +-
 .../core/providers/rocm/rocm_stream_handle.cc | 36 +-
 .../core/providers/rocm/rocm_stream_handle.h | 9 +-
 onnxruntime/core/providers/rocm/rocm_utils.cu | 9 +-
 .../providers/rocm/shared_inc/fast_divmod.h | 90 -
 .../providers/rocm/shared_inc/rocm_call.h | 4 +
 .../test/contrib_ops/element_wise_ops_test.cc | 26 +-
 onnxruntime/test/contrib_ops/fft_op_test.cc | 22 +-
 .../test/contrib_ops/greedy_search_test.cc | 14 +-
 tools/ci_build/amd_hipify.py | 20 +
 46 files changed, 1501 insertions(+), 1504 deletions(-)
 delete mode 100644 onnxruntime/core/providers/rocm/rocm_fwd.h
 delete mode 100644 onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h

diff --git
a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 96c05e5282bb5..14a1909cf4e80 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1516,6 +1516,7 @@ if (onnxruntime_USE_ROCM) find_package(hiprand REQUIRED) find_package(rocblas REQUIRED) find_package(MIOpen REQUIRED) + find_package(hipfft REQUIRED) # MIOpen version if(NOT DEFINED ENV{MIOPEN_PATH}) @@ -1554,7 +1555,7 @@ if (onnxruntime_USE_ROCM) find_library(RCCL_LIB rccl REQUIRED) find_library(ROCTRACER_LIB roctracer64 REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB}) + set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB}) file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index cf71b6bcf7c7d..d34b58bced22c 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,15 +48,6 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" - "math/complex_mul.cc" - "math/complex_mul.h" - "math/complex_mul_impl.cu" - "math/complex_mul_impl.h" - "math/cufft_plan_cache.h" - "math/fft_ops.cc" - "math/fft_ops.h" - "math/fft_ops_impl.cu" - "math/fft_ops_impl.h" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" @@ -86,19 +77,6 @@ set(contrib_ops_excluded_files "quantization/qordered_ops/qordered_unary_ops.cc" "quantization/qordered_ops/qordered_unary_ops_impl.h" "quantization/qordered_ops/qordered_unary_ops_impl.cu" - "tensor/crop.cc" - "tensor/crop.h" - "tensor/crop_impl.cu" - "tensor/crop_impl.h" - "tensor/dynamicslice.cc" - "tensor/image_scaler.cc" - "tensor/image_scaler.h" - "tensor/image_scaler_impl.cu" - "tensor/image_scaler_impl.h" - "transformers/greedy_search.cc" - "transformers/greedy_search.h" - "conv_transpose_with_dynamic_pads.cc" - "conv_transpose_with_dynamic_pads.h" "cuda_contrib_kernels.cc" "cuda_contrib_kernels.h" "inverse.cc" @@ -114,10 +92,6 @@ endif() set(provider_excluded_files "atomic/common.cuh" - "controlflow/loop.cc" - "controlflow/loop.h" - "controlflow/scan.cc" - "controlflow/scan.h" "cu_inc/common.cuh" "math/einsum_utils/einsum_auxiliary_ops.cc" "math/einsum_utils/einsum_auxiliary_ops.h" @@ -165,7 +139,6 @@ set(provider_excluded_files "cuda_memory_check.h" "cuda_fence.cc" "cuda_fence.h" - "cuda_fwd.h" "cuda_kernel.h" "cuda_pch.cc" "cuda_pch.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c391f47e1927b..3761737994db7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -217,7 +217,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy, update_gpt_feeds_func_ ? 
update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -240,7 +240,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { device_copy_int32_func_, update_gpt_feeds_fp16_func_, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -271,7 +271,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -293,7 +293,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_, expand_buffer_float16_func_, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -320,7 +320,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? 
expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -341,7 +341,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_, expand_buffer_float16_func_, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 93b7e08fabf94..ff2fe875678a1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -66,7 +66,7 @@ class BeamSearch : public IControlFlowKernel { create_beam_scorer_func_ = create_beam_scorer_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) void SetDeviceHelpers_Cuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func) { @@ -96,7 +96,7 @@ class BeamSearch : public IControlFlowKernel { expand_buffer_float16_func_ = expand_buffer_float16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; #endif @@ -115,7 +115,7 @@ class BeamSearch : public IControlFlowKernel { GenerationDeviceHelper::InitBeamStateFunc init_beam_state_fp16_func_; GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 205d94fae9fab..50417667d887f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -46,7 +46,7 @@ class BeamSearchGpt : public BeamSearchBase { update_feeds_func_(update_feeds_func), create_beam_scorer_func_(create_beam_scorer_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const void* cuda_device_prop, @@ -100,7 +100,7 @@ class BeamSearchGpt : public BeamSearchBase { GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_; GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; @@ -336,7 +336,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch // Increase sequence length after a new token is generated. 
++current_length; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Reorder past state after first run if the GPT subgraph (the one used after the first iteration) // contains DecoderMaskedSelfAttention nodes if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 14a0db57c45de..b3e80d9482e7f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -53,7 +53,7 @@ class BeamSearchT5 : public BeamSearchBase { expand_buffer_float16_func_(expand_buffer_float16_func), create_beam_scorer_func_(create_beam_scorer_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func, @@ -87,7 +87,7 @@ class BeamSearchT5 : public BeamSearchBase { // Device specific functions GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; #endif @@ -280,7 +280,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size(); beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz); -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Here we only need to reorder the past key for self-attention and cross-attention. 
for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) { ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 198dec011c56f..b22a0cf44fefa 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -51,7 +51,7 @@ class BeamSearchWhisper : public BeamSearchBase { expand_buffer_float16_func_(expand_buffer_float16_func), create_beam_scorer_func_(create_beam_scorer_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func, @@ -85,7 +85,7 @@ class BeamSearchWhisper : public BeamSearchBase { // Device specific functions GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; #endif @@ -272,7 +272,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size(); beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz); -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Here we only need to reorder the past key for self-attention and cross-attention. for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) { ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index ba1b0b662f1a5..90c06d0fd64d4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -33,7 +33,7 @@ enum DeviceCopyDirection { namespace GenerationDeviceHelper { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) using ReorderPastStateFunc = std::function, device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, update_gpt_feeds_func_ ? 
update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -227,7 +227,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { init_greedy_state_fp16_func_, device_copy_func_, update_gpt_feeds_fp16_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h index a065255766e31..d49db65dbcdd8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h @@ -60,7 +60,7 @@ class GreedySearch : public IControlFlowKernel { init_greedy_state_fp16_func_ = init_greedy_state_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) { reorder_past_state_func_ = reorder_past_state_func; } @@ -73,7 +73,7 @@ class GreedySearch : public IControlFlowKernel { update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; #endif @@ -90,7 +90,7 @@ class GreedySearch : public IControlFlowKernel { GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_fp16_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 4504b099e32bd..8ae948bce85e5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -60,7 +60,7 @@ class GreedySearchGpt : public GreedySearchBase { init_greedy_state_func_(init_greedy_state_func), update_feeds_func_(update_feeds_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const void* cuda_device_prop, @@ -109,7 +109,7 @@ class GreedySearchGpt : public GreedySearchBase { GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_; GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; @@ -336,7 +336,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ // Increase sequence length after a new token is generated. 
++current_length; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Reorder past state after first run if the GPT subgraph (the one used after the first iteration) // contains DecoderMaskedSelfAttention nodes if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index 101b059848908..7545298528453 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -139,7 +139,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const { init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState, device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -163,7 +163,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const { init_greedy_state_fp16_func_, device_copy_func_, update_gpt_feeds_fp16_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h index 8f048921a68e0..73425390c6b79 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -57,7 +57,7 @@ class Sampling : public IControlFlowKernel { init_greedy_state_fp16_func_ = init_greedy_state_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) { reorder_past_state_func_ = reorder_past_state_func; } @@ -70,7 +70,7 @@ class Sampling : public IControlFlowKernel { update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) const void* gpu_device_prop_ = nullptr; int gpu_device_arch_ = 0; #endif @@ -87,7 +87,7 @@ class Sampling : public IControlFlowKernel { GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_fp16_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7bc0f99081169..2de8189450df5 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -29,6 +29,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +// class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); @@ -52,6 +60,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); @@ -61,12 +73,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); @@ -113,6 +124,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); @@ -162,70 +184,73 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as // contrib ops to maintain backward compatibility - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -238,7 +263,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -249,16 +273,25 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo - + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -278,6 +311,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 429ceb1f7c699..5f966ac746fcb 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ 
b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -2,8 +2,6 @@ // Licensed under the MIT License. #pragma once -#include -#include #include #include #include @@ -294,6 +292,14 @@ __device__ __inline__ T _Gelu(T a) { return a * _Normcdf(a); } +template <> +__device__ __inline__ half _Gelu(half a) { + const half kHalf = half(0.5); + const half kOne = half(1.0); + const half kAlpha = half(M_SQRT1_2); + return a * kHalf * (kOne + _Erf(kAlpha * a)); +} + template __device__ __inline__ T _Mod(T a, T b) { T r = a % b; @@ -348,21 +354,19 @@ struct GridDim { }; }; -// aligned vector generates vectorized load/store +// aligned vector generates vectorized load/store on ROCM template struct alignas(sizeof(T) * vec_size) aligned_vector { T val[vec_size]; }; -#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \ +#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \ HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x; \ - if (id >= N) \ + if (id >= N) \ return; // HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels. -// TODO ROCM added support recently, should verify. -#define HIP_KERNEL_ASSERT(...) -// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) +#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) // WARP related definitions and functions constexpr int GPU_WARP_SIZE = warpSize; diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu index 4df7e0b5a5e3b..d130758bec084 100644 --- a/onnxruntime/core/providers/rocm/fpgeneric.cu +++ b/onnxruntime/core/providers/rocm/fpgeneric.cu @@ -68,7 +68,7 @@ rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocbla rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const half* x, int incx, half* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - CopyVectorHalf<<>>(x, incx, y, incy, n); + CopyVectorHalf<<>>(x, incx, y, incy, n); return rocblas_status_success; } @@ -76,6 +76,6 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons onnxruntime::BFloat16* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - CopyVectorBFloat16<<>>(x, incx, y, incy, n); + CopyVectorBFloat16<<>>(x, incx, y, incy, n); return rocblas_status_success; } diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc index fd45ad675ac3e..635a25480b646 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc @@ -2,14 +2,15 @@ // Licensed under the MIT License. 
#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" + #include "core/providers/rocm/gpu_data_transfer.h" +#include "core/providers/rocm/rocm_common.h" -// use default stream for copy for now, to avoid racing in BFC arena as in issue #4829 -// note this may cause some models to run slower if there are ops running on CPU -// so we leave it as optional, in case user need the previous behavior -// a full fix to BFC arena is being looked at, and once it's in, we can revert this change namespace onnxruntime { +GPUDataTransfer::GPUDataTransfer() {} + +GPUDataTransfer::~GPUDataTransfer() {} + bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; @@ -34,12 +35,12 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const } else { // copy from other CPU memory to GPU, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } else { // copying between cpu memory memcpy(dst_data, src_data, bytes); @@ -57,34 +58,29 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { + if (src_device.Type() == OrtDevice::CPU) { // copy from pinned memory to GPU, this is non-blocking HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - // Copy only if the two addresses are different. 
if (dst_data != src_data) { HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); } - } else { - // copy from other CPU memory to GPU, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { - if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) { + if (dst_device.Type() == OrtDevice::CPU) { // copying from GPU to pinned memory, this is non-blocking HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - } else { - // copying from GPU to CPU memory, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } } else { - // copying between cpu memory + if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + // sync the stream first to make sure the data arrived + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); + } memcpy(dst_data, src_data, bytes); } return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h index 3d35ed52fff5c..3d297bdce4a93 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.h @@ -10,8 +10,8 @@ namespace onnxruntime { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer() = default; - ~GPUDataTransfer() = default; + GPUDataTransfer(); + ~GPUDataTransfer(); bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; diff --git a/onnxruntime/core/providers/rocm/integer_gemm.cc b/onnxruntime/core/providers/rocm/integer_gemm.cc index 3c82a436d74e0..9771f42fd3637 100644 --- a/onnxruntime/core/providers/rocm/integer_gemm.cc +++ b/onnxruntime/core/providers/rocm/integer_gemm.cc @@ -5,13 +5,14 @@ #include #include "core/providers/rocm/shared_inc/integer_gemm.h" +#include "core/common/safeint.h" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/shared_inc/rocm_call.h" namespace onnxruntime { namespace rocm { -inline int roundoff(int v, int d) { +constexpr int roundoff(int v, int d) { return (v + d - 1) / d * d; } @@ -21,20 +22,21 @@ Status GemmInt8(int m, int n, int k, const RocmKernel* rocm_kernel, onnxruntime::Stream* ort_stream) { ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null"); + ORT_ENFORCE(ort_stream != nullptr, "Rocm kernel must have the stream instance"); - hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + hipStream_t stream = static_cast(ort_stream->GetHandle()); // pad A and B to make their leading dimension be multiples of 32 - // because cublasGemmEx requires: + // because rocblas_gemm_ex requires: // 1. leading dimension is multiples of 4 // 2. 
A, B is 32-bit aligned - const int mask = 0x1F; + constexpr int mask = 0x1F; int lda_aligned = lda; IAllocatorUniquePtr a_padded; if ((mask & lda_aligned) != 0) { lda_aligned = roundoff(lda, 32); - a_padded = rocm_kernel->GetScratchBuffer(m * lda_aligned, ort_stream); + a_padded = rocm_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, ort_stream); HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream)); } @@ -42,14 +44,15 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr b_padded; if ((mask & ldb_aligned) != 0) { ldb_aligned = roundoff(ldb, 32); - b_padded = rocm_kernel->GetScratchBuffer(k * ldb_aligned, ort_stream); + b_padded = rocm_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, ort_stream); HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream)); } - RocmStream* ort_rocm_stream = static_cast(ort_stream); - auto handle = ort_rocm_stream->rocblas_handle_; + auto* ort_rocm_stream = dynamic_cast(ort_stream); + auto rocblas = ort_rocm_stream->rocblas_handle_; + ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex( - handle, + rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &alpha, diff --git a/onnxruntime/core/providers/rocm/math/einsum.h b/onnxruntime/core/providers/rocm/math/einsum.h index a4adc3da98436..6be412348e6dd 100644 --- a/onnxruntime/core/providers/rocm/math/einsum.h +++ b/onnxruntime/core/providers/rocm/math/einsum.h @@ -17,8 +17,7 @@ class Einsum final : public onnxruntime::Einsum { Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) { // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method // TODO: Clean up the ROCMExecutionProvider interface to avoid this - rocm_ep_ = const_cast( - static_cast(info.GetExecutionProvider())); + rocm_ep_ = static_cast(info.GetExecutionProvider()); } Status Compute(OpKernelContext* context) const override; @@ -32,7 +31,7 @@ class Einsum final : public onnxruntime::Einsum { using onnxruntime::Einsum::equation_; // We need to access to the ROCM EP instance to get the rocblas/miopen handles - ROCMExecutionProvider* rocm_ep_; + const ROCMExecutionProvider* rocm_ep_; }; } // namespace rocm diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h index 623bb1d590a27..e1fc3f40ee9a5 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h @@ -21,19 +21,18 @@ namespace EinsumOp { // Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow struct EinsumRocmAssets { explicit EinsumRocmAssets(rocblas_handle rocblas_handle, - ROCMExecutionProvider* rocm_ep, - Stream* ort_stream, - AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle), - rocm_ep_(rocm_ep), - ort_stream_(ort_stream), - gpu_allocator_(gpu_allocator) {} + const ROCMExecutionProvider* rocm_ep, + Stream* ort_stream, AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle), + rocm_ep_(rocm_ep), + ort_stream_(ort_stream), + gpu_allocator_(gpu_allocator) {} hipStream_t GetRocmStream() { return ort_stream_ ? 
static_cast(ort_stream_->GetHandle()) : nullptr; } rocblas_handle rocblas_handle_; - ROCMExecutionProvider* rocm_ep_; + const ROCMExecutionProvider* rocm_ep_; Stream* ort_stream_; AllocatorPtr gpu_allocator_; }; diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc index 5a07737d92a02..bae1c000ddfcc 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ b/onnxruntime/core/providers/rocm/math/softmax.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace rocm { -template +template Status SoftMaxComputeHelper( Stream* stream, const T* X, @@ -29,20 +29,23 @@ Status SoftMaxComputeHelper( auto X_data = reinterpret_cast(X); if (D <= 1024 && D * sizeof(T) <= 4096) { - return dispatch_warpwise_softmax_forward, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), - gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); + return dispatch_warpwise_softmax_forward< + HipT_IN, HipT_OUT, AccumulationType_t, is_log_softmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); } - return dispatch_blockwise_softmax_forward, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), - gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); + + return dispatch_blockwise_softmax_forward, is_log_softmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(D), + gsl::narrow_cast(N), tuning_ctx); } -#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, const TensorShape& shape, TOut* Y, \ - int64_t axis, RocmTuningContext* tuning_ctx); \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, const TensorShape& shape, TOut* Y, \ - int64_t axis, RocmTuningContext* tuning_ctx); +#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ + template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + const TensorShape& shape, TOut* Y, int64_t axis, \ + RocmTuningContext* tuning_ctx); \ + template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + const TensorShape& shape, TOut* Y, int64_t axis, \ + RocmTuningContext* tuning_ctx); SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16, float) SPECIALIZED_SOFTMAX_HELPER_IMPL(float, float) diff --git a/onnxruntime/core/providers/rocm/math/softmax.h b/onnxruntime/core/providers/rocm/math/softmax.h index 49bfddad36b43..0a5571bd57dde 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.h +++ b/onnxruntime/core/providers/rocm/math/softmax.h @@ -11,7 +11,7 @@ namespace rocm { using tunable::RocmTuningContext; -template +template Status SoftMaxComputeHelper( Stream* stream, const T* input, @@ -20,14 +20,14 @@ Status SoftMaxComputeHelper( int64_t axis, RocmTuningContext* tuning_ctx = nullptr); -template -Status dispatch_warpwise_softmax_forward(Stream* stream, OutputT* dst, const InputT* src, int softmax_elements, - int softmax_elements_stride, int batch_count, +template +Status dispatch_warpwise_softmax_forward(Stream* stream, output_t* dst, const input_t* src, + int softmax_elements, int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx = nullptr); -template -Status dispatch_blockwise_softmax_forward(Stream* stream, OutputT* output, const InputT* input, int softmax_elements, - int input_stride, int output_stride, int batch_count, +template +Status dispatch_blockwise_softmax_forward(Stream* stream, output_t* output, const input_t* input, + int softmax_elements, int input_stride, 
int output_stride, int batch_count, RocmTuningContext* tuning_ctx = nullptr); template diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index 6846813c7cb48..6214ec7bc0ea3 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -44,14 +44,13 @@ const miopenConvFwdAlgorithm_t Conv::kAllAlgos[] = { miopenConvolutionFwdAlgoWinograd, miopenConvolutionFwdAlgoImplicitGEMM}; -miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, - miopenConvFwdAlgorithm_t algo, size_t* sz) { +miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, miopenConvFwdAlgorithm_t algo, size_t* sz) { return miopenConvolutionForwardGetWorkSpaceSize(handle, s.w_desc, s.x_tensor, s.conv_desc, s.y_tensor, sz); } size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, const miopenConvFwdAlgorithm_t* algo, int n_algo) { - // TODO: get maximum available size from memory arean + // TODO: get maximum available size from memory arena size_t free, total; HIP_CALL_THROW(hipMemGetInfo(&free, &total)); // Assuming 10% of fragmentation @@ -68,8 +67,7 @@ size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState& input_dims, + const void* input_data, gsl::span input_dims, void* output_data, const gsl::span& output_dims, const gsl::span& starts, @@ -103,8 +101,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. constexpr bool channels_last = NHWC; if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); } // set B @@ -140,7 +137,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const size_t kernel_rank = kernel_shape.size(); - ConvAttributes::ConvPadVector pads(conv_attrs_.pads); + ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_rank * 2, 0); } @@ -174,7 +171,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) TensorShapeVector slice_axes; slice_axes.reserve(kernel_rank); - const size_t spatial_dim_start = channels_last ? 1 : 2; + constexpr size_t spatial_dim_start = channels_last ? 1 : 2; const size_t spatial_dim_end = spatial_dim_start + kernel_rank; TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); @@ -183,7 +180,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, post_slicing_required, slice_starts, slice_ends, slice_axes, channels_last)); - if (channels_last) { y_dims.push_back(M); y_dims_with_adjusted_pads.push_back(M); @@ -198,9 +194,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.slice_axes = slice_axes; s_.Y = context->Output(0, TensorShape(s_.y_dims)); - if (s_.Y->Shape().Size() == 0) { - return Status::OK(); - } if (post_slicing_required) { // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. 
s_.memory_for_miopen_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); @@ -225,18 +218,23 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } if (w_dims_changed) { - if (channels_last) { + if (!channels_last) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); + } else { ORT_RETURN_IF_ERROR(s_.w_desc.Set(MiopenTensor::GetDataType(), miopenTensorNHWC, w_dims[0], w_dims[3], w_dims[1], w_dims[2])); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); } } + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + if (channels_last) { ORT_RETURN_IF_ERROR(s_.x_tensor.Set(MiopenTensor::GetDataType(), miopenTensorNHWC, @@ -357,7 +355,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions // This may have lead to extra results that are unnecessary and hence we slice that off here if (s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, s_.y_dims_with_adjusted_pads, + ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, s_.slice_ends, s_.slice_axes, s_.element_size)); } @@ -384,18 +382,18 @@ MiopenConvolutionDescriptor::~MiopenConvolutionDescriptor() { Status MiopenConvolutionDescriptor::Set( size_t rank, - gsl::span pads, - gsl::span strides, - gsl::span dilations, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, int groups, miopenConvolutionMode_t mode, miopenDataType_t data_type) { if (!desc_) MIOPEN_RETURN_IF_ERROR(miopenCreateConvolutionDescriptor(&desc_)); - InlinedVector pad_dims(rank); - InlinedVector stride_dims(rank); - InlinedVector dilation_dims(rank); + InlinedVector pad_dims(rank); + InlinedVector stride_dims(rank); + InlinedVector dilation_dims(rank); for (size_t i = 0; i < rank; i++) { pad_dims[i] = gsl::narrow_cast(pads[i]); stride_dims[i] = gsl::narrow_cast(strides[i]); diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h index f4f2331e9197e..bc9846203e57d 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ b/onnxruntime/core/providers/rocm/nn/conv.h @@ -10,6 +10,9 @@ #include namespace onnxruntime { + +using ConvPadVector = ConvAttributes::ConvPadVector; + namespace rocm { class MiopenConvolutionDescriptor final { @@ -18,9 +21,9 @@ class MiopenConvolutionDescriptor final { ~MiopenConvolutionDescriptor(); Status Set(size_t rank, - gsl::span pads, - gsl::span strides, - gsl::span dilations, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, int groups, miopenConvolutionMode_t mode, miopenDataType_t data_type); @@ -198,7 +201,7 @@ class Conv : public RocmKernel { Status SliceOutUnwantedOutputSection(hipStream_t stream, const void* input_data, - const gsl::span& input_dims, + gsl::span input_dims, void* output_data, const gsl::span& output_dims, const gsl::span& starts, diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc index 475d26d2e306d..23e9faedb1e76 100644 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ 
b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc @@ -93,9 +93,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } s_.y_dims = gsl::make_span(y_dims); - if (w_dims_changed) { + if (w_dims_changed) ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); - } // Special case when there is a dim value of 0 in the shape. // Return only after we have cached the following for subsequent runs : diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 4f726017d8b14..820745b22f614 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -8,6 +8,9 @@ #include "core/providers/rocm/math/binary_elementwise_ops_impl.h" #include "core/providers/rocm/math/binary_elementwise_ops.h" #include "core/providers/rocm/math/unary_elementwise_ops_impl.h" +#ifdef ENABLE_TRAINING +#include "contrib_ops/cpu/aten_ops/aten_op.h" +#endif using namespace onnxruntime::common; namespace onnxruntime { @@ -100,8 +103,8 @@ namespace rocm { (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -// ROCM ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now. -#define REGISTER_KERNEL_TYPED_11(name, T) \ +// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet +#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -110,10 +113,10 @@ namespace rocm { kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 11, \ + 11, 11, \ T, \ kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ @@ -166,7 +169,6 @@ Status ReduceKernel::ReduceKernelShared( const auto rank = input_shape.NumDimensions(); auto hip_stream = stream ? static_cast(stream->GetHandle()) : nullptr; - // Block of fast matrix reduction. if (fast_reduction_) { int m{}, n{}; @@ -210,10 +212,8 @@ Status ReduceKernel::ReduceKernelShared( ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, MiopenTensor::GetDataType(), ReduceTensorIndices)); else ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, ReduceTensorIndices)); - const auto one = ReduceConsts::One; const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; MiopenTensor output_tensor; ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); @@ -444,17 +444,18 @@ template Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, gsl::span axes, - bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, Stream* ort_stream, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, const TensorShape* input_shape_override) { typedef typename ToHipType::MappedType HipT; const TensorShape& input_shape = input_shape_override ? *input_shape_override : input.Shape(); + hipStream_t stream = ort_stream ? 
static_cast(ort_stream->GetHandle()) : nullptr; int64_t input_count = prepare_reduce_metadata.input_count; int64_t output_count = prepare_reduce_metadata.output_count; auto& output_dims = prepare_reduce_metadata.output_dims; auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; - hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // special case when there is a dim value of 0 in the shape. if (input_count == 0) { assert(output.Shape().Size() == 0); @@ -540,7 +541,6 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, const auto one = ReduceConsts::One; const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; MiopenTensor output_tensor; ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); @@ -588,11 +588,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, input_tensor, output_tensor, &indices_bytes_max)); auto indices_rocm_max = indices_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitRocmNotificationOnDevice); + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, indices_rocm_max.get(), indices_bytes_max, workspace_rocm.get(), workspace_bytes, &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } // Exp(X-ReduceMax) @@ -652,11 +653,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if (input_count == output_count) { HIP_RETURN_IF_ERROR(hipMemcpyAsync(reinterpret_cast(output.MutableData()), input_data, input_count * sizeof(T), hipMemcpyDeviceToDevice, stream)); } else { + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, &one, input_tensor, input_data, - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } } else { // miopenReduceTensor for ReduceSum has an issue if input and output have the same size; we just need to copy the data for this case @@ -675,11 +677,12 @@ Impl_Cast(stream, temp_output.get(), reinterpret_cast(output.MutableData()), output_count); } else { + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } } } @@ -743,18 +746,29 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR // empty axes and no-op if (axes.empty() && noop_with_empty_axes_) { auto* Y = ctx->Output(0, X->Shape()); - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream(ctx))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), + 
hipMemcpyDeviceToDevice, Stream(ctx))); return Status::OK(); } +#ifdef ENABLE_TRAINING + // Use ATen for ReduceSum if possible. + const TensorShape& input_shape = X->Shape(); + if (contrib::IsATenOperatorExecutorInitialized() && miopen_reduce_op == MIOPEN_REDUCE_TENSOR_ADD && !calculate_log_ && + !calculate_sqt_ && !log_sum_exp_ && input_shape.Size() > 0) { + if (axes.empty()) { + axes.resize(input_shape.NumDimensions()); + std::iota(axes.begin(), axes.end(), 0); + } + ORT_RETURN_IF_ERROR(contrib::ExecuteReduceSumATen(ctx, axes, keepdims_)); + return Status::OK(); + } +#endif + PrepareReduceMetadata prepare_reduce_metadata; - ORT_RETURN_IF_ERROR(PrepareForReduce(X, - keepdims_, - axes, - prepare_reduce_metadata)); + ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, miopen_reduce_op, axes, calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream()); } @@ -837,7 +851,6 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(GetMiopenHandle(ctx), reduce_desc, indices_rocm.get(), indices_bytes, \ workspace_rocm.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \ &zero, output_tensor, temp_Y.get())); \ - \ Impl_Cast(Stream(ctx), temp_Y.get(), reinterpret_cast(Y->MutableData()), output_count); \ \ return Status::OK(); \ @@ -909,13 +922,13 @@ template std::unique_ptr ReduceCompute #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" #include "core/platform/ort_mutex.h" @@ -56,7 +55,7 @@ class ROCMPinnedAllocator : public IAllocator { ROCMPinnedAllocator(const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0), + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device always with id 0*/), 0, OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index 730f55608c725..484e59f4de7d8 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -39,11 +39,11 @@ const char* RocmErrString(rocblas_status e) { CASE_ENUM_TO_STR(rocblas_status_invalid_handle); CASE_ENUM_TO_STR(rocblas_status_not_implemented); CASE_ENUM_TO_STR(rocblas_status_invalid_pointer); + CASE_ENUM_TO_STR(rocblas_status_size_query_mismatch); CASE_ENUM_TO_STR(rocblas_status_invalid_size); CASE_ENUM_TO_STR(rocblas_status_memory_error); CASE_ENUM_TO_STR(rocblas_status_internal_error); CASE_ENUM_TO_STR(rocblas_status_perf_degraded); - CASE_ENUM_TO_STR(rocblas_status_size_query_mismatch); CASE_ENUM_TO_STR(rocblas_status_size_increased); CASE_ENUM_TO_STR(rocblas_status_size_unchanged); CASE_ENUM_TO_STR(rocblas_status_invalid_value); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index c9975d0bc76c0..3c106313a89cd 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
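A note on the ENABLE_TRAINING ReduceSum path added in the reduction_ops.cc hunk above: ONNX treats an empty `axes` list (when `noop_with_empty_axes` is unset) as "reduce over every dimension", so before delegating to ATen the kernel materializes the full axis list with std::iota. A minimal standalone sketch of that expansion, using a hypothetical ExpandEmptyAxes helper (illustrative only, not part of this patch):

#include <cstdint>
#include <numeric>
#include <vector>

// Expands an empty reduction-axis list to [0, 1, ..., rank-1], mirroring what
// ComputeImpl does before calling contrib::ExecuteReduceSumATen.
std::vector<int64_t> ExpandEmptyAxes(std::vector<int64_t> axes, size_t rank) {
  if (axes.empty()) {
    axes.resize(rank);
    std::iota(axes.begin(), axes.end(), 0);  // fill with 0, 1, ..., rank-1
  }
  return axes;
}

For a rank-3 input, ExpandEmptyAxes({}, 3) yields {0, 1, 2}, i.e. a full reduction to a single element (or to all-ones dims when keepdims is set).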
+#include "core/common/inlined_containers.h" #include "core/providers/shared_library/provider_api.h" #include "core/platform/env_var_utils.h" #include "core/providers/rocm/rocm_execution_provider.h" @@ -9,7 +10,6 @@ #include "core/providers/rocm/rocm_fwd.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/rocm_profiler.h" -#include "core/providers/rocm/rocm_stream_handle.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/rocm/rocm_contrib_kernels.h" @@ -23,6 +23,8 @@ #include "core/providers/rocm/triton_kernel.h" #endif +#include "core/providers/rocm/rocm_stream_handle.h" + using namespace onnxruntime::common; namespace onnxruntime { @@ -38,42 +40,64 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); - return gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()); - } else if (X_type->IsSparseTensorType()) { - const auto* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); - SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); - return X->Copy(Info().GetDataTransferManager(), *Y); - } else if (X_type->IsTensorSequenceType()) { - const TensorSeq* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); - TensorSeq* Y = ctx->Output(0); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); - auto X_dtype = X->DataType(); - Y->SetType(X_dtype); - AllocatorPtr alloc; - auto status = ctx->GetTempSpaceAllocator(&alloc); - if (!status.IsOK()) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Memcpy rocm: unable to get an allocator."); - } - auto X_size = X->Size(); - Y->Reserve(X_size); - for (size_t i = 0; i < X_size; ++i) { - const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); - const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device); - Status retval = gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream()); - if (!retval.IsOK()) { - return retval; + // do we support async copy? + // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, + // so we don't need the check here. 
+ auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); + return Status::OK(); + } else { + if (X_type->IsSparseTensorType()) { + // TODO: support async copy for sparse tensor + // sync the stream first, since it is a sync memory copy + HIP_CALL_THROW(hipStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle()))); + const auto* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); + SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); + return X->Copy(Info().GetDataTransferManager(), *Y); + } else if (X_type->IsTensorSequenceType()) { + const TensorSeq* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); + TensorSeq* Y = ctx->Output(0); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); + auto X_dtype = X->DataType(); + Y->SetType(X_dtype); + AllocatorPtr alloc; + + // If we are copying contents to ROCM, the allocator to use + // to allocate the buffers of the new tensors in the sequence + // can be temp space allocator associated with the ROCM EP + if (Node().OpType() == "MemcpyFromHost") { + auto status = ctx->GetTempSpaceAllocator(&alloc); + if (!status.IsOK()) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Memcpy rocm: unable to get an allocator."); + } + } else { + // If we are copying contents to CPU (op type is "MemcpyToHost"), + // the allocator to use to allocate the buffers of the new tensors + // in the sequence will be the allocator from the CPU EP + auto status = ctx->GetTempSpaceCPUAllocator(&alloc); + if (!status.IsOK()) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Memcpy rocm: unable to get the CPU allocator."); + } + } + auto X_size = X->Size(); + Y->Reserve(X_size); + for (size_t i = 0; i < X_size; ++i) { + const Tensor& source_tensor = X->Get(i); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, + target_tensor->Location().device); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream())); + Y->Add(std::move(*target_tensor)); } - Y->Add(std::move(*target_tensor)); + return Status::OK(); } - return Status::OK(); + return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); } - return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); } }; @@ -100,18 +124,23 @@ ONNX_OPERATOR_KERNEL_EX( } // namespace rocm -AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t gpu_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) { +AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, + size_t gpu_mem_limit, + ArenaExtendStrategy arena_extend_strategy, + ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, + const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, 
external_allocator_info.alloc, external_allocator_info.free, external_allocator_info.empty_cache); + return std::make_unique(id, HIP, + external_allocator_info.alloc, + external_allocator_info.free, + external_allocator_info.empty_cache); }, device_id, false); return CreateAllocator(default_memory_info); - } else { AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId id) { @@ -120,12 +149,7 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi device_id, true, {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, - static_cast(arena_extend_strategy), - -1, - -1, - -1, - -1)}, + : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware true, // enable cross stream sharing? @@ -149,20 +173,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de } ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { - // dtor shouldn't throw. if something went wrong earlier (e.g. out of ROCM memory) the handles - // here may be bad, and the destroy calls can throw. - // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-noexcept - try { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - } catch (const std::exception& ex) { - LOGS_DEFAULT(ERROR) << "rocblas_destroy_handle threw:" << ex.what(); - } - - try { - MIOPEN_CALL_THROW(miopenDestroy(miopen_handle_)); - } catch (const std::exception& ex) { - LOGS_DEFAULT(ERROR) << "miopenDestroy threw:" << ex.what(); - } + ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(rocblas_handle_))); + ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); } void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { @@ -235,7 +247,7 @@ ROCMExecutionProvider::~ROCMExecutionProvider() { } if (!external_stream_ && stream_) { - HIP_CALL_THROW(hipStreamDestroy(stream_)); + ORT_IGNORE_RETURN_VALUE(HIP_CALL(hipStreamDestroy(stream_))); } } @@ -315,7 +327,7 @@ Status ROCMExecutionProvider::OnRunStart() { Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { if (sync_stream) { - HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_)); + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream_))); } // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), @@ -716,12 +728,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod); // opset 11 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ArgMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ArgMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -774,7 +786,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 18, Scan); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Slice); @@ -827,12 +839,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); - class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 13, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_float_int64_t, OneHot); @@ -1087,6 +1097,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub); @@ -1105,17 +1126,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); @@ -1186,12 +1196,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, 18, Shape); // Opset 16 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, 
kOnnxDomain, 16, double, LeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LeakyRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double_t, Where); @@ -1260,977 +1271,929 @@ KernelCreateInfo BuildKernelCreateInfo() { static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, 
- BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 10 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 11 - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 12 - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 13 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, 
- BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 14 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 15 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 16 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 17 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 18 - BuildKernelCreateInfo, - - // Opset 19 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + 
[Kernel registration function_table elided: this hunk re-lists several hundred BuildKernelCreateInfo<...> entries for the ROCM EP, grouped by the comments // opset 10, // opset 11, // OpSet 12, // OpSet 13, // OpSet 14, // OpSet 15, // Opset 16, // Opset 17, // Opset 18, and // Opset 19, with not-yet-supported kernels kept as commented-out entries. The template arguments (the ONNX_OPERATOR_*_KERNEL_CLASS_NAME class names) were lost when this patch was extracted, so the individual operator entries are not recoverable here; a self-contained sketch of the registration mechanism follows.]
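For orientation, each elided entry is a function pointer produced by instantiating BuildKernelCreateInfo with a kernel class tag; the loop that follows the closing brace simply invokes every entry and registers the result. Below is a minimal, self-contained model of that mechanism; all type and operator names in it are illustrative stand-ins, not the patch's actual entries.

```cpp
// Minimal model of the function_table registration mechanism (illustrative names only).
#include <iostream>
#include <string>

struct KernelCreateInfo {
  std::string op_name;  // stands in for the full kernel definition
};

using BuildKernelCreateInfoFn = KernelCreateInfo (*)();

// One builder per registered kernel class; the real table instantiates this with
// ONNX_OPERATOR_*_KERNEL_CLASS_NAME(kRocmExecutionProvider, ...) generated types.
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();

struct Relu_14 {};     // illustrative kernel tag types
struct Softmax_13 {};

template <>
KernelCreateInfo BuildKernelCreateInfo<Relu_14>() { return {"Relu, opset 14"}; }
template <>
KernelCreateInfo BuildKernelCreateInfo<Softmax_13>() { return {"Softmax, opset 13"}; }

int main() {
  static const BuildKernelCreateInfoFn function_table[] = {
      BuildKernelCreateInfo<Relu_14>,
      BuildKernelCreateInfo<Softmax_13>,
  };
  for (auto& function_table_entry : function_table) {
    std::cout << "registering: " << function_table_entry().op_name << "\n";
  }
  return 0;
}
```

The real table differs only in scale and in that the tag types come from the ONNX_OPERATOR_* macros rather than hand-written structs.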
}; for (auto& function_table_entry : function_table) { @@ -2336,7 +2299,6 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // These are usually shape related computation subgraphs // Following logic can be extended for other EPs auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); - std::vector<std::unique_ptr<ComputeCapability>> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) @@ -2371,7 +2333,7 @@ OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const std::vector<AllocatorPtr> ROCMExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_memory_info( - [](OrtDevice::DeviceId device_id) { + [](OrtDevice::DeviceId) { return std::make_unique<ROCMPinnedAllocator>(HIP_PINNED); }, // TODO: should we use info_.device_id instead of DEFAULT_CPU_ALLOCATOR_DEVICE_ID? @@ -2383,7 +2345,8 @@ std::vector<AllocatorPtr> ROCMExecutionProvider::CreatePreferredAllocators() { return std::vector<AllocatorPtr>{ CreateRocmAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg), - CreateAllocator(pinned_memory_info)}; + CreateAllocator(pinned_memory_info), + }; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 3e86afb7d643c..c4945b9ac2481 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -36,11 +36,11 @@ class ROCMExecutionProvider : public IExecutionProvider { return nullptr; } - rocblas_handle PerThreadRocblasHandle() { + rocblas_handle PerThreadDefaultRocblasHandle() { return GetPerThreadContext().RocblasHandle(); } - miopenHandle_t PerThreadMiopenHandle() { + miopenHandle_t PerThreadDefaultMiopenHandle() { return GetPerThreadContext().MiopenHandle(); } @@ -60,7 +60,6 @@ class ROCMExecutionProvider : public IExecutionProvider { const hipDeviceProp_t& GetDeviceProp() const { return device_prop_; }; int GetMiopenConvExhaustiveSearch() const { return info_.miopen_conv_exhaustive_search; } bool DoCopyOnDefaultStream() const { return info_.do_copy_in_default_stream; } - bool GetMiopenConvUseMaxWorkspace() const { return info_.miopen_conv_use_max_workspace; } ProviderOptions GetProviderOptions() const override { @@ -68,15 +67,15 @@ } static AllocatorPtr CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); + ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); ITuningContext* GetTuningContext() const override; std::unique_ptr<profiling::EpProfiler> GetProfiler() override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - std::vector<AllocatorPtr> CreatePreferredAllocators() override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const
override; + std::vector<AllocatorPtr> CreatePreferredAllocators() override; private: ROCMExecutionProviderInfo info_; @@ -105,21 +104,30 @@ class ROCMExecutionProvider : public IExecutionProvider { template <typename T> const T* GetConstOnes(size_t count, hipStream_t stream) { - if (std::is_same<T, float>::value) { + constexpr bool is_float = std::is_same<T, float>::value; + constexpr bool is_double = std::is_same<T, double>::value; + constexpr bool is_half = std::is_same<T, half>::value; + constexpr bool is_BFloat16 = std::is_same<T, BFloat16>::value; + if (is_float) { if (!constant_ones_float_) { constant_ones_float_ = rocm::CreateConstantOnes<float>(); } return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(stream, count)); - } else if (std::is_same<T, double>::value) { + } else if (is_double) { if (!constant_ones_double_) { constant_ones_double_ = rocm::CreateConstantOnes<double>(); } return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(stream, count)); - } else if (std::is_same<T, half>::value) { + } else if (is_half) { if (!constant_ones_half_) { constant_ones_half_ = rocm::CreateConstantOnes<half>(); } return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(stream, count)); + } else if (is_BFloat16) { + if (!constant_ones_bfloat16_) { + constant_ones_bfloat16_ = rocm::CreateConstantOnes<BFloat16>(); + } + return reinterpret_cast<const T*>(constant_ones_bfloat16_->GetBuffer(stream, count)); } else { return nullptr; } @@ -132,6 +140,7 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr<rocm::IConstantBuffer<float>> constant_ones_float_; std::unique_ptr<rocm::IConstantBuffer<double>> constant_ones_double_; std::unique_ptr<rocm::IConstantBuffer<half>> constant_ones_half_; + std::unique_ptr<rocm::IConstantBuffer<BFloat16>> constant_ones_bfloat16_; }; using PerThreadContextMap = std::unordered_map<const void*, std::weak_ptr<PerThreadContext>>; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 91e3aaaa4280f..650635c153640 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -27,12 +27,10 @@ constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; } // namespace provider_option_names } // namespace rocm -namespace { const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{ {ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"}, {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -} // namespace ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { ROCMExecutionProviderInfo info{}; @@ -81,7 +79,9 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { .AddAssignmentToEnumReference( rocm::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) - .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvExhaustiveSearch, info.miopen_conv_exhaustive_search) + .AddAssignmentToReference( + rocm::provider_option_names::kMiopenConvExhaustiveSearch, + info.miopen_conv_exhaustive_search) .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) .AddValueParser( diff --git a/onnxruntime/core/providers/rocm/rocm_fwd.h b/onnxruntime/core/providers/rocm/rocm_fwd.h deleted file mode 100644 index b123446fa9be1..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_fwd.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -namespace rocm { -template <typename T> -KernelCreateInfo BuildKernelCreateInfo(); -} -} // namespace onnxruntime
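An aside on the GetConstOnes() change above: the provider lazily materializes one device buffer of ones per element type (float, double, half, and now BFloat16) and hands kernels a typed pointer into it, growing the buffer on demand. The following is a self-contained sketch of the same lazy per-type caching idea, using host memory and illustrative names rather than the EP's HIP buffers.

```cpp
// Sketch of lazy, per-type "constant ones" caching (host memory stand-in).
#include <vector>

template <typename T>
class ConstantOnes {
 public:
  // Grows the cached buffer on demand and returns a pointer to `count` ones.
  const T* GetBuffer(size_t count) {
    if (data_.size() < count) data_.assign(count, T{1});
    return data_.data();
  }

 private:
  std::vector<T> data_;
};

template <typename T>
const T* GetConstOnes(size_t count) {
  // One cache per element type; the EP keeps these as members
  // (constant_ones_float_, constant_ones_half_, ...) instead of statics.
  static ConstantOnes<T> cache;
  return cache.GetBuffer(count);
}

int main() {
  const float* f = GetConstOnes<float>(8);    // allocates and fills once
  const double* d = GetConstOnes<double>(4);  // separate cache per type
  return (f[0] == 1.0f && d[3] == 1.0) ? 0 : 1;
}
```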
diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 02f15fdad8b77..463c1cf0d2ea6 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -35,14 +35,12 @@ class RocmKernel : public OpKernel { // use this to precisely locate the node where ROCM failure comes from // if (hipSuccess != hipDeviceSynchronize()) // __debugbreak(); - if (s.IsOK()) { auto err = hipGetLastError(); if (err != hipSuccess) { - s = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP error ", hipGetErrorName(err), ":", hipGetErrorString(err)); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP error ", hipGetErrorName(err), ":", hipGetErrorString(err)); } } - return s; } @@ -64,18 +62,18 @@ class RocmKernel : public OpKernel { return IAllocator::MakeUniquePtr<T>(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, true); } - template <typename T> - inline IAllocatorUniquePtr<T> AllocateBufferOnCPUPinned(size_t count_or_bytes) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr<T>(Info().GetAllocator(OrtMemType::OrtMemTypeCPU), count_or_bytes); - } - inline void AddDeferredReleaseCPUPtr(void* p, onnxruntime::Stream* ort_stream) const { ORT_ENFORCE(ort_stream->GetDevice().Type() == OrtDevice::GPU); auto* rocm_ep_stream = static_cast<RocmStream*>(ort_stream); rocm_ep_stream->EnqueDeferredCPUBuffer(p); } + template <typename T> + inline IAllocatorUniquePtr<T> AllocateBufferOnCPUPinned(size_t count_or_bytes) const { + if (count_or_bytes == 0) return nullptr; + return IAllocator::MakeUniquePtr<T>(Info().GetAllocator(OrtMemType::OrtMemTypeCPU), count_or_bytes); + } + const hipDeviceProp_t& GetDeviceProp() const { return provider_->GetDeviceProp(); } inline hipStream_t Stream(OpKernelContext* ctx) const { @@ -83,6 +81,22 @@ return stream ? static_cast<hipStream_t>(stream->GetHandle()) : nullptr; }
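A note on the Compute()/ComputeInternal() change just shown: returning the ORT_MAKE_STATUS result directly means a sticky HIP launch error is surfaced (and cleared) immediately after the kernel body runs, instead of being written into a local that then falls through. Below is a compilable sketch of that wrapping pattern; Status and OpKernelContext are simplified stand-ins for the real ONNX Runtime types, and it assumes a ROCm toolchain to build.

```cpp
// Sketch of the Compute() wrapper polling hipGetLastError() after the kernel body.
#include <hip/hip_runtime.h>
#include <string>

struct Status {
  std::string msg;  // empty means success
  bool IsOK() const { return msg.empty(); }
};

struct OpKernelContext {};  // opaque stand-in

// Stands in for a derived kernel's body that launches HIP kernels.
static Status ComputeInternal(OpKernelContext*) { return {}; }

static Status Compute(OpKernelContext* ctx) {
  Status s = ComputeInternal(ctx);
  if (s.IsOK()) {
    // hipGetLastError() both reports and clears the sticky error, so a failed
    // asynchronous launch surfaces here rather than in a later, unrelated kernel.
    hipError_t err = hipGetLastError();
    if (err != hipSuccess) {
      return {std::string("HIP error ") + hipGetErrorName(err) + ": " + hipGetErrorString(err)};
    }
  }
  return s;
}

int main() {
  OpKernelContext ctx;
  return Compute(&ctx).IsOK() ? 0 : 1;
}
```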
+ inline miopenHandle_t GetMiopenHandle(OpKernelContext* ctx) const { + return GetMiopenHandle(static_cast<RocmStream*>(ctx->GetComputeStream())); + } + + static inline miopenHandle_t GetMiopenHandle(onnxruntime::RocmStream* stream) { + return stream->miopen_handle_; + } + + inline rocblas_handle GetRocblasHandle(OpKernelContext* ctx) const { + return GetRocblasHandle(static_cast<RocmStream*>(ctx->GetComputeStream())); + } + + static inline rocblas_handle GetRocblasHandle(onnxruntime::RocmStream* stream) { + return stream->rocblas_handle_; + } + tunable::RocmTuningContext* GetTuningContext() const { return static_cast<tunable::RocmTuningContext*>(provider_->GetTuningContext()); } @@ -106,7 +120,7 @@ class RocmKernel : public OpKernel { } } - RocmAsyncBuffer(const RocmKernel* op_kernel, gsl::span<const T> vec) : RocmAsyncBuffer(op_kernel, vec.size()) { + RocmAsyncBuffer(const RocmKernel* op_kernel, gsl::span<T const> vec) : RocmAsyncBuffer(op_kernel, vec.size()) { memcpy(CpuPtr(), vec.data(), vec.size() * sizeof(T)); } @@ -151,28 +165,12 @@ class RocmKernel : public OpKernel { const RocmKernel* op_kernel_; }; - inline rocblas_handle RocblasHandle() const { - return provider_->PerThreadRocblasHandle(); - } - - inline miopenHandle_t MiopenHandle() const { - return provider_->PerThreadMiopenHandle(); + inline rocblas_handle DefaultRocblasHandle() const { + return provider_->PerThreadDefaultRocblasHandle(); } - static inline rocblas_handle GetRocblasHandle(onnxruntime::RocmStream* stream) { - return stream->rocblas_handle_; - } - - inline rocblas_handle GetRocblasHandle(OpKernelContext* ctx) const { - return GetRocblasHandle(static_cast<RocmStream*>(ctx->GetComputeStream())); - } - - static inline miopenHandle_t GetMiopenHandle(onnxruntime::RocmStream* stream) { - return stream->miopen_handle_; - } - - inline miopenHandle_t GetMiopenHandle(OpKernelContext* ctx) const { - return GetMiopenHandle(static_cast<RocmStream*>(ctx->GetComputeStream())); + inline miopenHandle_t DefaultMiopenHandle() const { + return provider_->PerThreadDefaultMiopenHandle(); } protected: diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index e55b2edbad685..f59e609b8f355 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -3,15 +3,13 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/rocm/rocm_provider_factory.h" - -#include +#include "core/providers/rocm/rocm_provider_factory_creator.h" #include "core/common/gsl.h" #include "core/providers/rocm/rocm_execution_provider.h" #include "core/providers/rocm/rocm_execution_provider_info.h" #include "core/providers/rocm/rocm_allocator.h" -#include "core/providers/rocm/rocm_provider_factory_creator.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/math/unary_elementwise_ops_impl.h" @@ -47,7 +45,7 @@ std::unique_ptr<IExecutionProvider> ROCMProviderFactory::CreateProvider() { return std::make_unique<ROCMExecutionProvider>(info_); } -struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { +struct ProviderInfo_ROCM_Impl final : ProviderInfo_ROCM { OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) override { int num_devices; auto hip_err = ::hipGetDeviceCount(&num_devices); @@ -128,9 +126,26 @@ struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { } // Used by slice_concatenate_test.cc and onnxruntime_pybind_state.cc - void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { HIP_CALL_THROW(hipMemcpy(dst, src, count,
hipMemcpyHostToDevice)); } + + void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { + // hipMemcpy() operates on the default stream + HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); + + // To ensure that the copy has completed, invoke a stream sync for the default stream. + // https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync + // For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. + // The function will return once the pageable buffer has been copied to the staging memory for DMA transfer + // to device memory, but the DMA to final destination may not have completed. + + HIP_CALL_THROW(hipStreamSynchronize(0)); + } + // Used by onnxruntime_pybind_state.cc - void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } + void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { + // https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync + // For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. + HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); + } int hipGetDeviceCount() override { int num_devices = 0; @@ -152,10 +167,9 @@ struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { return std::make_shared<ROCMProviderFactory>(info); } - std::shared_ptr<IAllocator> CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) override { + std::shared_ptr<IAllocator> CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { return ROCMExecutionProvider::CreateRocmAllocator(device_id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); } - } g_info; struct ROCM_Provider : Provider { @@ -169,8 +183,8 @@ info.gpu_mem_limit = params->gpu_mem_limit; info.arena_extend_strategy = static_cast<ArenaExtendStrategy>(params->arena_extend_strategy); info.miopen_conv_exhaustive_search = params->miopen_conv_exhaustive_search; - info.do_copy_in_default_stream = params->do_copy_in_default_stream; - info.has_user_compute_stream = params->has_user_compute_stream; + info.do_copy_in_default_stream = params->do_copy_in_default_stream != 0; + info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; info.tunable_op.enable = params->tunable_op_enable; @@ -180,21 +194,32 @@ return std::make_shared<ROCMProviderFactory>(info); } + /** + * This function will be called by the C API UpdateROCMProviderOptions(). + * + * What this function does is equivalent to resetting the OrtROCMProviderOptions instance with + * a default ROCMExecutionProviderInfo instance first and then setting up the provided provider options. + * See ROCMExecutionProviderInfo::FromProviderOptions() for more details.
+ */ void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto info = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); + auto internal_options = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); auto& rocm_options = *reinterpret_cast<OrtROCMProviderOptions*>(provider_options); - rocm_options.device_id = info.device_id; - rocm_options.gpu_mem_limit = info.gpu_mem_limit; - rocm_options.arena_extend_strategy = static_cast<int>(info.arena_extend_strategy); - rocm_options.miopen_conv_exhaustive_search = info.miopen_conv_exhaustive_search; - rocm_options.do_copy_in_default_stream = info.do_copy_in_default_stream; - rocm_options.has_user_compute_stream = info.has_user_compute_stream; - rocm_options.user_compute_stream = info.user_compute_stream; - rocm_options.default_memory_arena_cfg = info.default_memory_arena_cfg; - rocm_options.tunable_op_enable = info.tunable_op.enable; - rocm_options.tunable_op_tuning_enable = info.tunable_op.tuning_enable; - rocm_options.tunable_op_max_tuning_duration_ms = info.tunable_op.max_tuning_duration_ms; + rocm_options.device_id = internal_options.device_id; + rocm_options.gpu_mem_limit = internal_options.gpu_mem_limit; + rocm_options.arena_extend_strategy = static_cast<int>(internal_options.arena_extend_strategy); + rocm_options.miopen_conv_exhaustive_search = internal_options.miopen_conv_exhaustive_search; + rocm_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; + rocm_options.has_user_compute_stream = internal_options.has_user_compute_stream; + // The 'has_user_compute_stream' of the OrtROCMProviderOptions instance can be set by the C API UpdateROCMProviderOptionsWithValue() as well. + // We only set the 'user_compute_stream' of the OrtROCMProviderOptions instance if it is provided in options + if (options.find("has_user_compute_stream") != options.end()) { + rocm_options.user_compute_stream = internal_options.user_compute_stream; + } + rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; + rocm_options.tunable_op_enable = internal_options.tunable_op.enable; + rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; + rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; } ProviderOptions GetProviderOptions(const void* provider_options) override { @@ -219,4 +244,6 @@ extern "C" { ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } + } + diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.h b/onnxruntime/core/providers/rocm/rocm_provider_factory.h index 8cd7bd357330f..80b887af4eb75 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.h +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.h @@ -3,6 +3,7 @@ #include "onnxruntime_c_api.h" #include "core/framework/provider_options.h" +#include "core/common/common.h" namespace onnxruntime { class IAllocator; @@ -43,7 +44,16 @@ struct ProviderInfo_ROCM { #endif virtual std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory(const onnxruntime::ROCMExecutionProviderInfo& info) = 0; - virtual std::shared_ptr<IAllocator> CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr<IAllocator> CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0;
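Stepping out of the diff briefly: as the doc comment above says, UpdateProviderOptions() is reached through the C API. The sketch below shows that call path as the comment describes it; the exact signatures of UpdateROCMProviderOptions and SessionOptionsAppendExecutionProvider_ROCM are assumed here by analogy with their CUDA counterparts in the ORT C API headers, and the option keys/values are examples only, with error handling trimmed.

```cpp
// Hypothetical usage sketch of the C API path into UpdateProviderOptions().
#include <onnxruntime_c_api.h>

void configure_rocm_ep(OrtSessionOptions* session_options) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtROCMProviderOptions rocm_options{};  // zero-initialized defaults
  const char* keys[] = {"device_id", "tunable_op_enable", "tunable_op_tuning_enable"};
  const char* values[] = {"0", "1", "1"};
  // Internally parses the strings via ROCMExecutionProviderInfo::FromProviderOptions()
  // and writes the result back into rocm_options, as in the diff above.
  (void)api->UpdateROCMProviderOptions(&rocm_options, keys, values, 3);

  (void)api->SessionOptionsAppendExecutionProvider_ROCM(session_options, &rocm_options);
}
```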
+ + // This function is the entry point to ROCM EP's UT cases. + // All tests are only called from onnxruntime_test_all. + virtual void TestAll() { + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is only implemented in the test code path."); + } + + protected: + ~ProviderInfo_ROCM() = default; // Can only be destroyed through a subclass instance }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc index 0d9877e6b18e6..670aae91ca710 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc @@ -1,7 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/rocm/rocm_resource.h" #include "core/providers/rocm/rocm_stream_handle.h" #include "core/providers/rocm/rocm_common.h" // #include "core/common/spin_pause.h" -#include "core/providers/rocm/rocm_resource.h" namespace onnxruntime { @@ -82,15 +84,29 @@ void RocmStream::EnqueDeferredCPUBuffer(void* cpu_buffer) { deferred_cpu_buffers_.push_back(cpu_buffer); } -struct CpuBuffersInfo { // TODO: should be moved to base class +struct CpuBuffersInfo { + // This struct stores the information needed + // to release CPU buffers allocated for GPU kernels. + // It's used to enqueue their release after + // associated GPU kernels in a ROCM stream. + + // This is a CPU allocator in ROCM EP. + // It must be the one used to allocate the + // following pointers. AllocatorPtr allocator; + // buffers[i] is the i-th pointer added by + // AddDeferredReleaseCPUPtr for a specific + // ROCM stream. For example, this field + // should contain all values in + // deferred_release_buffer_pool_[my_stream] + // when releasing my_stream's buffers. std::unique_ptr<void*[]> buffers; // Number of buffer pointers in "buffers". size_t n_buffers; }; -static void ReleaseCpuBufferCallback(hipStream_t /*stream*/, hipError_t /*status*/, void* raw_info) { // TODO: should be moved to base class +static void ReleaseCpuBufferCallback(void* raw_info) { std::unique_ptr<CpuBuffersInfo> info = std::make_unique<CpuBuffersInfo>(); info.reset(reinterpret_cast<CpuBuffersInfo*>(raw_info)); for (size_t i = 0; i < info->n_buffers; ++i) { @@ -111,14 +127,7 @@ Status RocmStream::CleanUpOnRunEnd() { cpu_buffers_info->buffers[i] = deferred_cpu_buffers_.at(i); } cpu_buffers_info->n_buffers = deferred_cpu_buffers_.size(); - // TODO(wechi): CUDA deprecates cudaStreamAddCallback and - // uses another API, cudaLaunchHostFunc(which can be - // captured in CUDA graph). Once AMD adds similar feature, - // we should replace the following line with - // hipLaunchHostFunc(stream, ReleaseCpuBufferCallback, cpu_buffers_info); - - // Release memory asynchronously to avoid blocking the compute stream.
- HIP_RETURN_IF_ERROR(hipStreamAddCallback(static_cast<hipStream_t>(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release(), 0)); + HIP_RETURN_IF_ERROR(hipLaunchHostFunc(static_cast<hipStream_t>(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release())); } else { HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(GetHandle()))); for (auto* buffer : deferred_cpu_buffers_) { @@ -130,10 +139,10 @@ Status RocmStream::CleanUpOnRunEnd() { return Status::OK(); } -void* RocmStream::GetResource(int version, int type) const { +void* RocmStream::GetResource(int version, int id) const { ORT_ENFORCE(version <= ORT_ROCM_RESOUCE_VERSION, "resource version unsupported!"); void* resource{}; - switch (type) { + switch (id) { case RocmResource::hip_stream_t: return reinterpret_cast<void*>(GetHandle()); break; @@ -149,6 +158,7 @@ return resource; } +// CPU Stream command handles void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { static_cast<RocmNotification*>(&notification)->wait_on_device(stream); } diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h index 865cff0abf85f..1f3e5b75548e7 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "core/providers/rocm/rocm_pch.h" // #include "core/providers/cuda/shared_inc/cuda_utils.h" @@ -17,14 +20,12 @@ struct RocmStream : Stream { ~RocmStream(); - std::unique_ptr<synchronize::Notification> CreateNotification(size_t num_consumers) override; + std::unique_ptr<synchronize::Notification> CreateNotification(size_t /*num_consumers*/) override; void Flush() override; Status CleanUpOnRunEnd() override; - void* GetResource(int version, int id) const override; - void EnqueDeferredCPUBuffer(void* cpu_buffer); bool own_stream_{true}; @@ -33,6 +34,8 @@ rocblas_handle rocblas_handle_{}; + void* GetResource(int version, int id) const override; + private: std::vector<void*> deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; diff --git a/onnxruntime/core/providers/rocm/rocm_utils.cu b/onnxruntime/core/providers/rocm/rocm_utils.cu index cbf410e78a4a9..b817e025cedf4 100644 --- a/onnxruntime/core/providers/rocm/rocm_utils.cu +++ b/onnxruntime/core/providers/rocm/rocm_utils.cu @@ -30,13 +30,14 @@ template <typename T> void Fill(hipStream_t stream, T* output, T value, int64_t count) { int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); HIP_LONG N = static_cast<HIP_LONG>(count); - _Fill<T, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(output, value, N); + _Fill<T, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread> + <<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(output, value, N); } template <typename T> class ConstantBufferImpl : public IConstantBuffer<T> { public: - ConstantBufferImpl(T val) : buffer_(nullptr), count_(0), val_(val) {} - + ConstantBufferImpl(T val) : buffer_(nullptr), count_(0), val_(val) { + } ~ConstantBufferImpl() { if (buffer_) HIP_CALL_THROW(hipFree(buffer_)); } @@ -70,6 +71,7 @@ std::unique_ptr<IConstantBuffer<T>> CreateConstantOnes() { template std::unique_ptr<IConstantBuffer<float>> CreateConstantOnes<float>(); template std::unique_ptr<IConstantBuffer<double>> CreateConstantOnes<double>(); template std::unique_ptr<IConstantBuffer<half>> CreateConstantOnes<half>(); +template std::unique_ptr<IConstantBuffer<BFloat16>> CreateConstantOnes<BFloat16>(); #define SPECIALIZED_FILL(T) \ template void Fill<T>(hipStream_t stream, T * output, T value, int64_t count); SPECIALIZED_FILL(int8_t) SPECIALIZED_FILL(int16_t) SPECIALIZED_FILL(int32_t) SPECIALIZED_FILL(int64_t) SPECIALIZED_FILL(float) SPECIALIZED_FILL(double) SPECIALIZED_FILL(__half) +SPECIALIZED_FILL(BFloat16) } // namespace rocm } // namespace onnxruntime
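The hipStreamAddCallback-to-hipLaunchHostFunc switch in CleanUpOnRunEnd() above resolves the old TODO: hipLaunchHostFunc enqueues a plain host function (no hipError_t or flags parameters) that runs once all previously submitted work on the stream completes. A freestanding sketch of the pattern, with the buffer standing in for the EP's deferred CPU allocations:

```cpp
// Deferred CPU-buffer release via hipLaunchHostFunc (freestanding sketch).
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

static void ReleaseCpuBuffer(void* raw) {
  // Runs on a HIP runtime thread once the stream reaches this point;
  // host functions enqueued this way must not call back into the HIP API.
  std::free(raw);
  std::printf("deferred CPU buffer released\n");
}

int main() {
  hipStream_t stream;
  (void)hipStreamCreate(&stream);
  void* cpu_scratch = std::malloc(1024);  // stands in for a deferred CPU buffer
  // ... enqueue kernels/copies that read cpu_scratch on `stream` ...
  (void)hipLaunchHostFunc(stream, ReleaseCpuBuffer, cpu_scratch);  // note: no flags arg, unlike hipStreamAddCallback
  (void)hipStreamSynchronize(stream);
  (void)hipStreamDestroy(stream);
  return 0;
}
```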
diff --git a/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h b/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h deleted file mode 100644 index 83ca0a443c4fa..0000000000000 --- a/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h +++ /dev/null @@ -1,90 +0,0 @@ -// -// Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved -// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. -// - -#pragma once - -#include -#include -#include -#include -#include "core/common/common.h" - -namespace onnxruntime { -namespace rocm { - -// DivMod is a helper class for integer division and modulo operation. -// There is a fast version for the int type and a slow version for other types. -template <typename T> -struct DivMod { - DivMod(T d = 1) { - d_ = d == 0 ? 1 : d; - ORT_ENFORCE(d_ >= 1 && d_ <= std::numeric_limits<T>::max()); - } - - __host__ __device__ inline T div(T n) const { - return n / d_; - } - - __host__ __device__ inline T mod(T n) const { - return n % d_; - } - - __host__ __device__ inline void divmod(T n, T& q, T& r) const { - q = div(n); - r = n - q * d_; - } - - T d_; // divisor -}; - -// The code below is based on section 4 Unsigned division of paper https://gmplib.org/~tege/divcnst-pldi94.pdf -// In current ORT, fast_divmod is used for calculating the position of an element in a tensor, -// so unsigned integer division from the paper is good enough for ORT. The advantage is that div is very simple, -// then the GPU compiler can do loop unrolling easily when divmod is called in a loop. -template <> -struct DivMod<int> { - DivMod(int d = 1) { - d_ = d == 0 ? 1 : d; - ORT_ENFORCE(d_ >= 1 && d_ <= static_cast<uint32_t>(std::numeric_limits<int>::max())); - - for (l_ = 0; l_ < 32; l_++) - if ((1U << l_) >= d_) break; - - uint64_t one = 1; - uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1; - M_ = static_cast<uint32_t>(m); - // according to the paper, the value of m' should fit in an unsigned integer. - ORT_ENFORCE(M_ > 0 && M_ == m); - } - - __host__ __device__ inline int div(int n) const { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) - uint32_t t = __umulhi(M_, n); - return (t + n) >> l_; -#else - // Using uint64_t for t, then t + n won't overflow. - uint64_t t = ((uint64_t)M_ * n) >> 32; - return static_cast<int>((t + n) >> l_); -#endif - } - - __host__ __device__ inline int mod(int n) const { - return n - div(n) * d_; - } - - __host__ __device__ inline void divmod(int n, int& q, int& r) const { - q = div(n); - r = n - q * d_; - } - - uint32_t d_; // divisor - uint32_t M_; // m' in the paper. - uint32_t l_; // l_ = ceil(log2(d_)) -}; - -using fast_divmod = DivMod<int>; // Keep the old name for backward compatibility. - -} // namespace rocm -} // namespace onnxruntime
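For reference, the magic-number division that the deleted fast_divmod.h implemented (section 4 of the cited paper) can be checked numerically. A worked, host-only example for d = 3: l_ = 2 (the smallest l with 2^l >= 3), and M_ = (2^32 * (2^2 - 3)) / 3 + 1 = 1431655766 = 0x55555556, so div(n) = (((M_ * n) >> 32) + n) >> l_ uses only a multiply, an add, and shifts.

```cpp
// Numeric check of the DivMod<int> magic-number division for d = 3.
#include <cassert>
#include <cstdint>

int fast_div(uint32_t M, uint32_t l, int n) {
  uint64_t t = (static_cast<uint64_t>(M) * n) >> 32;  // high 32 bits of M*n
  return static_cast<int>((t + n) >> l);
}

int main() {
  const uint32_t M = 0x55555556u;  // m' for d = 3, as derived above
  const uint32_t l = 2;            // ceil(log2(3))
  for (int n = 0; n < 1000000; ++n) {
    assert(fast_div(M, l, n) == n / 3);  // matches plain division
  }
  return 0;
}
```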
diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h index d6623ef63f0fd..b6b40666b8bd0 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h @@ -17,16 +17,20 @@ std::conditional_t<THRW, void, Status> RocmCall( #define HIP_CALL(expr) (RocmCall<hipError_t, false>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL(expr) (RocmCall<rocblas_status, false>((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) + #define HIPSPARSE_CALL(expr) (RocmCall<hipsparseStatus_t, false>((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL(expr) (RocmCall<hiprandStatus_t, false>((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define MIOPEN_CALL(expr) (RocmCall<miopenStatus_t, false>((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) #define MIOPEN_CALL2(expr, m) (RocmCall<miopenStatus_t, false>((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) + #define HIPFFT_CALL(expr) (RocmCall<hipfftResult, false>((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall<hipError_t, true>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL_THROW(expr) (RocmCall<rocblas_status, true>((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) + #define HIPSPARSE_CALL_THROW(expr) (RocmCall<hipsparseStatus_t, true>((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL_THROW(expr) (RocmCall<hiprandStatus_t, true>((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) + #define MIOPEN_CALL_THROW(expr) (RocmCall<miopenStatus_t, true>((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) #define MIOPEN_CALL_THROW2(expr, m) (RocmCall<miopenStatus_t, true>((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) #define HIPFFT_CALL_THROW(expr) (RocmCall<hipfftResult, true>((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index 15e2449cd230f..fff103e5fa339 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -197,9 +197,8 @@ TEST(BiasGeluTest, BFloat16) { } #endif +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(MathOpTest, ComplexMul) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector<float> input_a_data = { -0.5f, 0.6f}; @@ -219,13 +218,15 @@ TEST(MathOpTest, ComplexMul) { tester.AddOutput<float>("C", {4, 2}, output_data); std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMulConj) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector<float> input_a_data = { -0.5f, 0.6f}; @@ -245,13 +246,15 @@ TEST(MathOpTest, ComplexMulConj) { tester.AddOutput<float>("C", {4, 2}, output_data); std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMul_fp16) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector<MLFloat16> input_a_data = { MLFloat16(-0.5f), MLFloat16(0.6f)};
@@ -271,13 +274,15 @@ TEST(MathOpTest, ComplexMul_fp16) { tester.AddOutput<MLFloat16>("C", {4, 2}, output_data); std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMulConj_fp16) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector<MLFloat16> input_a_data = { MLFloat16(-0.5f), MLFloat16(0.6f)}; @@ -297,9 +302,14 @@ TEST(MathOpTest, ComplexMulConj_fp16) { tester.AddOutput<MLFloat16>("C", {4, 2}, output_data); std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc index eaadb95c8a0c0..7092f41da3443 100644 --- a/onnxruntime/test/contrib_ops/fft_op_test.cc +++ b/onnxruntime/test/contrib_ops/fft_op_test.cc @@ -8,7 +8,14 @@ namespace onnxruntime { namespace test { TEST(ContribOpTest, Rfft) { - if (DefaultCudaExecutionProvider() == nullptr) return; + std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#else + return; +#endif OpTester test("Rfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast<int64_t>(1)); @@ -17,13 +24,18 @@ // Target values computed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward") test.AddInput<float>("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); test.AddOutput<float>("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); - std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(ContribOpTest, Irfft) { - if (DefaultCudaExecutionProvider() == nullptr) return; + std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#else + return; +#endif OpTester test("Irfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast<int64_t>(1)); @@ -31,8 +43,6 @@ test.AddAttribute("normalized", static_cast<int64_t>(0)); test.AddInput<float>("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); test.AddOutput<float>("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); - std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } // namespace test diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index f5259c1391f38..cf643fdf73753 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -50,11 +50,16 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) { const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"}; const char* const output_names[] = {"sequences"}; +#ifdef USE_CUDA constexpr int min_cuda_architecture = 530; - if (HasCudaEnvironment(min_cuda_architecture)) { + if (!HasCudaEnvironment(min_cuda_architecture)) return; +#endif + { Ort::SessionOptions session_options; #ifdef USE_CUDA Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); +#elif USE_ROCM + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); #endif // The following model was obtained by padding the vocabulary size in testdata/transformers/tiny_gpt2_beamsearch_fp16.onnx @@ -117,11 +122,16 @@ TEST(GreedySearchTest, GptGreedySearchFp32) { const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"}; const char* const output_names[] = {"sequences"}; +#ifdef USE_CUDA constexpr int min_cuda_architecture = 530; - if (HasCudaEnvironment(min_cuda_architecture)) { + if (!HasCudaEnvironment(min_cuda_architecture)) return; +#endif + { Ort::SessionOptions session_options; #ifdef USE_CUDA Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); +#elif USE_ROCM + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); #endif Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_greedysearch_with_init_decoder.onnx"), session_options); diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index e0293128045a5..6f492317524be 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -150,6 +150,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): # CUFFT -> HIPFFT s = s.replace("CUFFT", "HIPFFT") + s = s.replace("cufftXtMakePlanMany", "hipfftXtMakePlanMany") + s = s.replace("cufftXtExec", "hipfftXtExec") # Undo where above hipify steps went too far. s = s.replace("id, ROCM", "id, CUDA") # cuda_execution_provider.cc @@ -169,6 +171,24 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("#include ", "#include ") s = s.replace("#include ", "#include ") s = s.replace("#include ", "#include ") + s = s.replace("#include ", "#include ") + s = s.replace('#include "hipfft.h"', "#include ") + s = s.replace('#include "hipfftXt.h"', "#include ") + + # Fix onnxruntime/contrib_ops/rocm/transformers. They include cpu headers which use "cuda" in their names. + s = s.replace("rocm_device_prop_", "cuda_device_prop_") + s = s.replace("rocm_device_arch_", "cuda_device_arch_") + + # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names + # And we do this last, undoing or fixing hipify mistakes. 
+ if "fft" in src_file_path: + s = s.replace("rocblas_datatype", "hipDataType") + s = s.replace("hipDataType_f32_c", "HIP_C_32F") + s = s.replace("hipDataType_f32_r", "HIP_R_32F") + s = s.replace("hipDataType_f64_c", "HIP_C_64F") + s = s.replace("hipDataType_f64_r", "HIP_R_64F") + s = s.replace("hipDataType_f16_c", "HIP_C_16F") + s = s.replace("hipDataType_f16_r", "HIP_R_16F") with open(dst_file_path, "w") as f: f.write(s)