diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 96c05e5282bb5..14a1909cf4e80 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1516,6 +1516,7 @@ if (onnxruntime_USE_ROCM) find_package(hiprand REQUIRED) find_package(rocblas REQUIRED) find_package(MIOpen REQUIRED) + find_package(hipfft REQUIRED) # MIOpen version if(NOT DEFINED ENV{MIOPEN_PATH}) @@ -1554,7 +1555,7 @@ if (onnxruntime_USE_ROCM) find_library(RCCL_LIB rccl REQUIRED) find_library(ROCTRACER_LIB roctracer64 REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB}) + set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB}) file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index cf71b6bcf7c7d..d34b58bced22c 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,15 +48,6 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" - "math/complex_mul.cc" - "math/complex_mul.h" - "math/complex_mul_impl.cu" - "math/complex_mul_impl.h" - "math/cufft_plan_cache.h" - "math/fft_ops.cc" - "math/fft_ops.h" - "math/fft_ops_impl.cu" - "math/fft_ops_impl.h" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" @@ -86,19 +77,6 @@ set(contrib_ops_excluded_files "quantization/qordered_ops/qordered_unary_ops.cc" "quantization/qordered_ops/qordered_unary_ops_impl.h" "quantization/qordered_ops/qordered_unary_ops_impl.cu" - "tensor/crop.cc" - "tensor/crop.h" - "tensor/crop_impl.cu" - "tensor/crop_impl.h" - "tensor/dynamicslice.cc" - "tensor/image_scaler.cc" - "tensor/image_scaler.h" - "tensor/image_scaler_impl.cu" - "tensor/image_scaler_impl.h" - "transformers/greedy_search.cc" - "transformers/greedy_search.h" - "conv_transpose_with_dynamic_pads.cc" - "conv_transpose_with_dynamic_pads.h" "cuda_contrib_kernels.cc" "cuda_contrib_kernels.h" "inverse.cc" @@ -114,10 +92,6 @@ endif() set(provider_excluded_files "atomic/common.cuh" - "controlflow/loop.cc" - "controlflow/loop.h" - "controlflow/scan.cc" - "controlflow/scan.h" "cu_inc/common.cuh" "math/einsum_utils/einsum_auxiliary_ops.cc" "math/einsum_utils/einsum_auxiliary_ops.h" @@ -165,7 +139,6 @@ set(provider_excluded_files "cuda_memory_check.h" "cuda_fence.cc" "cuda_fence.h" - "cuda_fwd.h" "cuda_kernel.h" "cuda_pch.cc" "cuda_pch.h" diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c391f47e1927b..3761737994db7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -217,7 +217,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy, update_gpt_feeds_func_ ? 
update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -240,7 +240,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { device_copy_int32_func_, update_gpt_feeds_fp16_func_, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -271,7 +271,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -293,7 +293,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_, expand_buffer_float16_func_, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -320,7 +320,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? 
expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -341,7 +341,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_float_func_, expand_buffer_float16_func_, create_beam_scorer_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 93b7e08fabf94..ff2fe875678a1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -66,7 +66,7 @@ class BeamSearch : public IControlFlowKernel { create_beam_scorer_func_ = create_beam_scorer_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) void SetDeviceHelpers_Cuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func) { @@ -96,7 +96,7 @@ class BeamSearch : public IControlFlowKernel { expand_buffer_float16_func_ = expand_buffer_float16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; #endif @@ -115,7 +115,7 @@ class BeamSearch : public IControlFlowKernel { GenerationDeviceHelper::InitBeamStateFunc init_beam_state_fp16_func_; GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 205d94fae9fab..50417667d887f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -46,7 +46,7 @@ class BeamSearchGpt : public BeamSearchBase { update_feeds_func_(update_feeds_func), create_beam_scorer_func_(create_beam_scorer_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const void* cuda_device_prop, @@ -100,7 +100,7 @@ class BeamSearchGpt : public BeamSearchBase { GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_; GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; @@ -336,7 +336,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch // Increase sequence length after a new token is generated. 
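Note on the repeated guard flips in this file and the headers below: they are the core of the change, letting ROCm builds take the same host-side GPU initialization path that CUDA builds already use. A minimal standalone sketch of the pattern, hypothetical and not part of this patch:

```cpp
// guard_sketch.cpp -- hypothetical illustration; compile with -DUSE_CUDA or
// -DUSE_ROCM. Either macro now enables the same GPU-specific host code path.
#include <cstdio>

#if defined(USE_CUDA) || defined(USE_ROCM)
static void InitializeGpuHelpers() { std::puts("shared CUDA/ROCm initialization"); }
#else
static void InitializeGpuHelpers() { std::puts("CPU-only build: nothing to do"); }
#endif

int main() {
  InitializeGpuHelpers();
  return 0;
}
```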
++current_length; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Reorder past state after first run if the GPT subgraph (the one used after the first iteration) // contains DecoderMaskedSelfAttention nodes if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 14a0db57c45de..b3e80d9482e7f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -53,7 +53,7 @@ class BeamSearchT5 : public BeamSearchBase { expand_buffer_float16_func_(expand_buffer_float16_func), create_beam_scorer_func_(create_beam_scorer_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func, @@ -87,7 +87,7 @@ class BeamSearchT5 : public BeamSearchBase { // Device specific functions GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; #endif @@ -280,7 +280,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size(); beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz); -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Here we only need to reorder the past key for self-attention and cross-attention. 
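On the reorder loop that continues just below: with N decoder layers there are 2 * N past-key tensors, self-attention keys first and cross-attention keys second, which is why a single loop bound of 2 * num_layers covers both groups. A hypothetical restatement with an assumed layer count:

```cpp
// index_sketch.cpp -- hypothetical restatement of the loop bound used below.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t num_layers = 4;  // assumed value, for illustration only
  for (std::size_t i = 0; i < 2 * num_layers; ++i) {
    std::printf("past key %zu: %s-attention\n", i, i < num_layers ? "self" : "cross");
  }
  return 0;
}
```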
for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) { ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 198dec011c56f..b22a0cf44fefa 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -51,7 +51,7 @@ class BeamSearchWhisper : public BeamSearchBase { expand_buffer_float16_func_(expand_buffer_float16_func), create_beam_scorer_func_(create_beam_scorer_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func, @@ -85,7 +85,7 @@ class BeamSearchWhisper : public BeamSearchBase { // Device specific functions GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; #endif @@ -272,7 +272,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size(); beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz); -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Here we only need to reorder the past key for self-attention and cross-attention. for (size_t i = 0; i < 2 * static_cast(decoder_subgraph_.num_layers); ++i) { ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index ba1b0b662f1a5..90c06d0fd64d4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -33,7 +33,7 @@ enum DeviceCopyDirection { namespace GenerationDeviceHelper { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) using ReorderPastStateFunc = std::function, device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, update_gpt_feeds_func_ ? 
update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -227,7 +227,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { init_greedy_state_fp16_func_, device_copy_func_, update_gpt_feeds_fp16_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h index a065255766e31..d49db65dbcdd8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h @@ -60,7 +60,7 @@ class GreedySearch : public IControlFlowKernel { init_greedy_state_fp16_func_ = init_greedy_state_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) { reorder_past_state_func_ = reorder_past_state_func; } @@ -73,7 +73,7 @@ class GreedySearch : public IControlFlowKernel { update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; #endif @@ -90,7 +90,7 @@ class GreedySearch : public IControlFlowKernel { GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_fp16_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 4504b099e32bd..8ae948bce85e5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -60,7 +60,7 @@ class GreedySearchGpt : public GreedySearchBase { init_greedy_state_func_(init_greedy_state_func), update_feeds_func_(update_feeds_func) {} -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) Status InitializeCuda( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const void* cuda_device_prop, @@ -109,7 +109,7 @@ class GreedySearchGpt : public GreedySearchBase { GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_; GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; @@ -336,7 +336,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ // Increase sequence length after a new token is generated. 
++current_length; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // Reorder past state after first run if the GPT subgraph (the one used after the first iteration) // contains DecoderMaskedSelfAttention nodes if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc index 101b059848908..7545298528453 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -139,7 +139,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const { init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState, device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); @@ -163,7 +163,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const { init_greedy_state_fp16_func_, device_copy_func_, update_gpt_feeds_fp16_func_}; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_)); #endif ORT_RETURN_IF_ERROR(impl.Initialize()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h index 8f048921a68e0..73425390c6b79 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sampling.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -57,7 +57,7 @@ class Sampling : public IControlFlowKernel { init_greedy_state_fp16_func_ = init_greedy_state_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) { reorder_past_state_func_ = reorder_past_state_func; } @@ -70,7 +70,7 @@ class Sampling : public IControlFlowKernel { update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) const void* gpu_device_prop_ = nullptr; int gpu_device_arch_ = 0; #endif @@ -87,7 +87,7 @@ class Sampling : public IControlFlowKernel { GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_fp16_func_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; #endif diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7bc0f99081169..2de8189450df5 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -29,6 +29,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +// class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); @@ -52,6 +60,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); @@ -61,12 +73,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); @@ -113,6 +124,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); @@ -162,70 +184,73 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. 
We add them here as // contrib ops to maintain backward compatibility - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -238,7 +263,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -249,16 +273,25 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo - + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -278,6 +311,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 429ceb1f7c699..5f966ac746fcb 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ 
b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -2,8 +2,6 @@
 // Licensed under the MIT License.

 #pragma once
-#include
-#include
 #include
 #include
 #include
@@ -294,6 +292,14 @@ __device__ __inline__ T _Gelu(T a) {
   return a * _Normcdf(a);
 }

+template <>
+__device__ __inline__ half _Gelu(half a) {
+  const half kHalf = half(0.5);
+  const half kOne = half(1.0);
+  const half kAlpha = half(M_SQRT1_2);
+  return a * kHalf * (kOne + _Erf(kAlpha * a));
+}
+
 template <typename T>
 __device__ __inline__ T _Mod(T a, T b) {
   T r = a % b;
@@ -348,21 +354,19 @@ struct GridDim {
   };
 };

-// aligned vector generates vectorized load/store
+// aligned vector generates vectorized load/store on ROCM
 template <typename T, int vec_size>
 struct alignas(sizeof(T) * vec_size) aligned_vector {
   T val[vec_size];
 };

-#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \
+#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \
   HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x; \
-  if (id >= N) \
+  if (id >= N) \
     return;

 // HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels.
-// TODO ROCM added support recently, should verify.
-#define HIP_KERNEL_ASSERT(...)
-// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
+#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)

 // WARP related definitions and functions
 constexpr int GPU_WARP_SIZE = warpSize;
diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu
index 4df7e0b5a5e3b..d130758bec084 100644
--- a/onnxruntime/core/providers/rocm/fpgeneric.cu
+++ b/onnxruntime/core/providers/rocm/fpgeneric.cu
@@ -68,7 +68,7 @@ rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocbla
 rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const half* x, int incx, half* y, int incy) {
   dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
   dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
-  CopyVectorHalf<<>>(x, incx, y, incy, n);
+  CopyVectorHalf<<>>(x, incx, y, incy, n);
   return rocblas_status_success;
 }

@@ -76,6 +76,6 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons
                                  onnxruntime::BFloat16* y, int incy) {
   dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
   dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
-  CopyVectorBFloat16<<>>(x, incx, y, incy, n);
+  CopyVectorBFloat16<<>>(x, incx, y, incy, n);
   return rocblas_status_success;
 }
diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
index fd45ad675ac3e..635a25480b646 100644
--- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
+++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc
@@ -2,14 +2,15 @@
 // Licensed under the MIT License.
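On the _Gelu half specialization in the common.cuh hunk above: it evaluates x * 0.5 * (1 + erf(x / sqrt(2))) directly in half precision instead of routing through _Normcdf. A host-side float check, hypothetical and standalone, showing the two forms agree:

```cpp
// gelu_check.cpp -- hypothetical host-side check of the formula used by the
// new half specialization of _Gelu. It should match the generic path,
// x * normcdf(x), since normcdf(x) == 0.5 * erfc(-x / sqrt(2)).
#include <cmath>
#include <cstdio>

int main() {
  const float kInvSqrt2 = 0.7071067811865476f;  // 1/sqrt(2), i.e. M_SQRT1_2
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
    float erf_form = x * 0.5f * (1.0f + std::erf(x * kInvSqrt2));
    float cdf_form = x * 0.5f * std::erfc(-x * kInvSqrt2);  // x * normcdf(x)
    std::printf("x=%5.2f  erf form=%.6f  normcdf form=%.6f\n", x, erf_form, cdf_form);
  }
  return 0;
}
```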
#include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" + #include "core/providers/rocm/gpu_data_transfer.h" +#include "core/providers/rocm/rocm_common.h" -// use default stream for copy for now, to avoid racing in BFC arena as in issue #4829 -// note this may cause some models to run slower if there are ops running on CPU -// so we leave it as optional, in case user need the previous behavior -// a full fix to BFC arena is being looked at, and once it's in, we can revert this change namespace onnxruntime { +GPUDataTransfer::GPUDataTransfer() {} + +GPUDataTransfer::~GPUDataTransfer() {} + bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; @@ -34,12 +35,12 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const } else { // copy from other CPU memory to GPU, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } else { // copying between cpu memory memcpy(dst_data, src_data, bytes); @@ -57,34 +58,29 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { + if (src_device.Type() == OrtDevice::CPU) { // copy from pinned memory to GPU, this is non-blocking HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - // Copy only if the two addresses are different. 
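The CopyTensorAsync rewrite in this file drops the per-MemType branching: CPU-to-GPU and GPU-to-CPU copies are now always issued asynchronously on the provided stream, and a host-to-host memcpy synchronizes the stream first only when the source is pinned memory that an in-flight copy may still be writing. A hypothetical sketch of the resulting decision logic (the flag arguments are assumptions for illustration, not this file's API):

```cpp
// copy_sketch.cpp -- hypothetical restatement of the simplified copy logic.
#include <hip/hip_runtime.h>
#include <cstring>

hipError_t CopyAsyncSketch(void* dst, const void* src, size_t bytes,
                           bool dst_on_gpu, bool src_on_gpu, bool src_pinned,
                           hipStream_t stream) {
  if (dst_on_gpu) {
    return hipMemcpyAsync(dst, src, bytes,
                          src_on_gpu ? hipMemcpyDeviceToDevice : hipMemcpyHostToDevice,
                          stream);
  }
  if (src_on_gpu) {
    return hipMemcpyAsync(dst, src, bytes, hipMemcpyDeviceToHost, stream);
  }
  if (src_pinned) {
    // make sure any in-flight writes into the pinned buffer have landed
    hipError_t e = hipStreamSynchronize(stream);
    if (e != hipSuccess) return e;
  }
  std::memcpy(dst, src, bytes);
  return hipSuccess;
}

int main() {
  char src[8] = "rocm", dst[8] = {};
  // host-to-host path: no GPU work is issued in this configuration
  (void)CopyAsyncSketch(dst, src, sizeof(src), false, false, false, nullptr);
  return dst[0] == 'r' ? 0 : 1;
}
```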
if (dst_data != src_data) { HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); } - } else { - // copy from other CPU memory to GPU, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { - if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) { + if (dst_device.Type() == OrtDevice::CPU) { // copying from GPU to pinned memory, this is non-blocking HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - } else { - // copying from GPU to CPU memory, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } } else { - // copying between cpu memory + if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + // sync the stream first to make sure the data arrived + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); + } memcpy(dst_data, src_data, bytes); } return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h index 3d35ed52fff5c..3d297bdce4a93 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.h @@ -10,8 +10,8 @@ namespace onnxruntime { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer() = default; - ~GPUDataTransfer() = default; + GPUDataTransfer(); + ~GPUDataTransfer(); bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; diff --git a/onnxruntime/core/providers/rocm/integer_gemm.cc b/onnxruntime/core/providers/rocm/integer_gemm.cc index 3c82a436d74e0..9771f42fd3637 100644 --- a/onnxruntime/core/providers/rocm/integer_gemm.cc +++ b/onnxruntime/core/providers/rocm/integer_gemm.cc @@ -5,13 +5,14 @@ #include #include "core/providers/rocm/shared_inc/integer_gemm.h" +#include "core/common/safeint.h" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/shared_inc/rocm_call.h" namespace onnxruntime { namespace rocm { -inline int roundoff(int v, int d) { +constexpr int roundoff(int v, int d) { return (v + d - 1) / d * d; } @@ -21,20 +22,21 @@ Status GemmInt8(int m, int n, int k, const RocmKernel* rocm_kernel, onnxruntime::Stream* ort_stream) { ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null"); + ORT_ENFORCE(ort_stream != nullptr, "Rocm kernel must have the stream instance"); - hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + hipStream_t stream = static_cast(ort_stream->GetHandle()); // pad A and B to make their leading dimension be multiples of 32 - // because cublasGemmEx requires: + // because rocblas_gemm_ex requires: // 1. leading dimension is multiples of 4 // 2. 
A, B is 32-bit aligned - const int mask = 0x1F; + constexpr int mask = 0x1F; int lda_aligned = lda; IAllocatorUniquePtr a_padded; if ((mask & lda_aligned) != 0) { lda_aligned = roundoff(lda, 32); - a_padded = rocm_kernel->GetScratchBuffer(m * lda_aligned, ort_stream); + a_padded = rocm_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, ort_stream); HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream)); } @@ -42,14 +44,15 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr b_padded; if ((mask & ldb_aligned) != 0) { ldb_aligned = roundoff(ldb, 32); - b_padded = rocm_kernel->GetScratchBuffer(k * ldb_aligned, ort_stream); + b_padded = rocm_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, ort_stream); HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream)); } - RocmStream* ort_rocm_stream = static_cast(ort_stream); - auto handle = ort_rocm_stream->rocblas_handle_; + auto* ort_rocm_stream = dynamic_cast(ort_stream); + auto rocblas = ort_rocm_stream->rocblas_handle_; + ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex( - handle, + rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &alpha, diff --git a/onnxruntime/core/providers/rocm/math/einsum.h b/onnxruntime/core/providers/rocm/math/einsum.h index a4adc3da98436..6be412348e6dd 100644 --- a/onnxruntime/core/providers/rocm/math/einsum.h +++ b/onnxruntime/core/providers/rocm/math/einsum.h @@ -17,8 +17,7 @@ class Einsum final : public onnxruntime::Einsum { Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) { // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method // TODO: Clean up the ROCMExecutionProvider interface to avoid this - rocm_ep_ = const_cast( - static_cast(info.GetExecutionProvider())); + rocm_ep_ = static_cast(info.GetExecutionProvider()); } Status Compute(OpKernelContext* context) const override; @@ -32,7 +31,7 @@ class Einsum final : public onnxruntime::Einsum { using onnxruntime::Einsum::equation_; // We need to access to the ROCM EP instance to get the rocblas/miopen handles - ROCMExecutionProvider* rocm_ep_; + const ROCMExecutionProvider* rocm_ep_; }; } // namespace rocm diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h index 623bb1d590a27..e1fc3f40ee9a5 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h @@ -21,19 +21,18 @@ namespace EinsumOp { // Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow struct EinsumRocmAssets { explicit EinsumRocmAssets(rocblas_handle rocblas_handle, - ROCMExecutionProvider* rocm_ep, - Stream* ort_stream, - AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle), - rocm_ep_(rocm_ep), - ort_stream_(ort_stream), - gpu_allocator_(gpu_allocator) {} + const ROCMExecutionProvider* rocm_ep, + Stream* ort_stream, AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle), + rocm_ep_(rocm_ep), + ort_stream_(ort_stream), + gpu_allocator_(gpu_allocator) {} hipStream_t GetRocmStream() { return ort_stream_ ? 
static_cast(ort_stream_->GetHandle()) : nullptr; } rocblas_handle rocblas_handle_; - ROCMExecutionProvider* rocm_ep_; + const ROCMExecutionProvider* rocm_ep_; Stream* ort_stream_; AllocatorPtr gpu_allocator_; }; diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc index 5a07737d92a02..bae1c000ddfcc 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ b/onnxruntime/core/providers/rocm/math/softmax.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace rocm { -template +template Status SoftMaxComputeHelper( Stream* stream, const T* X, @@ -29,20 +29,23 @@ Status SoftMaxComputeHelper( auto X_data = reinterpret_cast(X); if (D <= 1024 && D * sizeof(T) <= 4096) { - return dispatch_warpwise_softmax_forward, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), - gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); + return dispatch_warpwise_softmax_forward< + HipT_IN, HipT_OUT, AccumulationType_t, is_log_softmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); } - return dispatch_blockwise_softmax_forward, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), - gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); + + return dispatch_blockwise_softmax_forward, is_log_softmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(D), + gsl::narrow_cast(N), tuning_ctx); } -#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, const TensorShape& shape, TOut* Y, \ - int64_t axis, RocmTuningContext* tuning_ctx); \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, const TensorShape& shape, TOut* Y, \ - int64_t axis, RocmTuningContext* tuning_ctx); +#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ + template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + const TensorShape& shape, TOut* Y, int64_t axis, \ + RocmTuningContext* tuning_ctx); \ + template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + const TensorShape& shape, TOut* Y, int64_t axis, \ + RocmTuningContext* tuning_ctx); SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16, float) SPECIALIZED_SOFTMAX_HELPER_IMPL(float, float) diff --git a/onnxruntime/core/providers/rocm/math/softmax.h b/onnxruntime/core/providers/rocm/math/softmax.h index 49bfddad36b43..0a5571bd57dde 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.h +++ b/onnxruntime/core/providers/rocm/math/softmax.h @@ -11,7 +11,7 @@ namespace rocm { using tunable::RocmTuningContext; -template +template Status SoftMaxComputeHelper( Stream* stream, const T* input, @@ -20,14 +20,14 @@ Status SoftMaxComputeHelper( int64_t axis, RocmTuningContext* tuning_ctx = nullptr); -template -Status dispatch_warpwise_softmax_forward(Stream* stream, OutputT* dst, const InputT* src, int softmax_elements, - int softmax_elements_stride, int batch_count, +template +Status dispatch_warpwise_softmax_forward(Stream* stream, output_t* dst, const input_t* src, + int softmax_elements, int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx = nullptr); -template -Status dispatch_blockwise_softmax_forward(Stream* stream, OutputT* output, const InputT* input, int softmax_elements, - int input_stride, int output_stride, int batch_count, +template +Status dispatch_blockwise_softmax_forward(Stream* stream, output_t* output, const input_t* input, + int softmax_elements, int input_stride, 
int output_stride, int batch_count, RocmTuningContext* tuning_ctx = nullptr); template diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index 6846813c7cb48..6214ec7bc0ea3 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -44,14 +44,13 @@ const miopenConvFwdAlgorithm_t Conv::kAllAlgos[] = { miopenConvolutionFwdAlgoWinograd, miopenConvolutionFwdAlgoImplicitGEMM}; -miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, - miopenConvFwdAlgorithm_t algo, size_t* sz) { +miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, miopenConvFwdAlgorithm_t algo, size_t* sz) { return miopenConvolutionForwardGetWorkSpaceSize(handle, s.w_desc, s.x_tensor, s.conv_desc, s.y_tensor, sz); } size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, const miopenConvFwdAlgorithm_t* algo, int n_algo) { - // TODO: get maximum available size from memory arean + // TODO: get maximum available size from memory arena size_t free, total; HIP_CALL_THROW(hipMemGetInfo(&free, &total)); // Assuming 10% of fragmentation @@ -68,8 +67,7 @@ size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState& input_dims, + const void* input_data, gsl::span input_dims, void* output_data, const gsl::span& output_dims, const gsl::span& starts, @@ -103,8 +101,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. constexpr bool channels_last = NHWC; if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); } // set B @@ -140,7 +137,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const size_t kernel_rank = kernel_shape.size(); - ConvAttributes::ConvPadVector pads(conv_attrs_.pads); + ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_rank * 2, 0); } @@ -174,7 +171,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) TensorShapeVector slice_axes; slice_axes.reserve(kernel_rank); - const size_t spatial_dim_start = channels_last ? 1 : 2; + constexpr size_t spatial_dim_start = channels_last ? 1 : 2; const size_t spatial_dim_end = spatial_dim_start + kernel_rank; TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); @@ -183,7 +180,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, post_slicing_required, slice_starts, slice_ends, slice_axes, channels_last)); - if (channels_last) { y_dims.push_back(M); y_dims_with_adjusted_pads.push_back(M); @@ -198,9 +194,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.slice_axes = slice_axes; s_.Y = context->Output(0, TensorShape(s_.y_dims)); - if (s_.Y->Shape().Size() == 0) { - return Status::OK(); - } if (post_slicing_required) { // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. 
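On the SoftMaxComputeHelper dispatch in the softmax hunks above: rows of at most 1024 elements that also fit in 4096 bytes go to the warp-wise kernel, everything else to the block-wise kernel. A hypothetical standalone restatement of that rule:

```cpp
// softmax_dispatch_sketch.cpp -- hypothetical restatement of the dispatch
// condition (D <= 1024 && D * sizeof(T) <= 4096) from softmax.cc.
#include <cstddef>
#include <cstdio>

static const char* PickKernel(std::size_t row_elems, std::size_t elem_size) {
  return (row_elems <= 1024 && row_elems * elem_size <= 4096) ? "warpwise" : "blockwise";
}

int main() {
  std::printf("fp16, D=1024: %s\n", PickKernel(1024, 2));  // 2048 bytes -> warpwise
  std::printf("fp32, D=1024: %s\n", PickKernel(1024, 4));  // 4096 bytes -> warpwise
  std::printf("fp32, D=2048: %s\n", PickKernel(2048, 4));  // too wide   -> blockwise
  return 0;
}
```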
s_.memory_for_miopen_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); @@ -225,18 +218,23 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } if (w_dims_changed) { - if (channels_last) { + if (!channels_last) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); + } else { ORT_RETURN_IF_ERROR(s_.w_desc.Set(MiopenTensor::GetDataType(), miopenTensorNHWC, w_dims[0], w_dims[3], w_dims[1], w_dims[2])); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); } } + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + if (channels_last) { ORT_RETURN_IF_ERROR(s_.x_tensor.Set(MiopenTensor::GetDataType(), miopenTensorNHWC, @@ -357,7 +355,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions // This may have lead to extra results that are unnecessary and hence we slice that off here if (s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, s_.y_dims_with_adjusted_pads, + ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, s_.slice_ends, s_.slice_axes, s_.element_size)); } @@ -384,18 +382,18 @@ MiopenConvolutionDescriptor::~MiopenConvolutionDescriptor() { Status MiopenConvolutionDescriptor::Set( size_t rank, - gsl::span pads, - gsl::span strides, - gsl::span dilations, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, int groups, miopenConvolutionMode_t mode, miopenDataType_t data_type) { if (!desc_) MIOPEN_RETURN_IF_ERROR(miopenCreateConvolutionDescriptor(&desc_)); - InlinedVector pad_dims(rank); - InlinedVector stride_dims(rank); - InlinedVector dilation_dims(rank); + InlinedVector pad_dims(rank); + InlinedVector stride_dims(rank); + InlinedVector dilation_dims(rank); for (size_t i = 0; i < rank; i++) { pad_dims[i] = gsl::narrow_cast(pads[i]); stride_dims[i] = gsl::narrow_cast(strides[i]); diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h index f4f2331e9197e..bc9846203e57d 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ b/onnxruntime/core/providers/rocm/nn/conv.h @@ -10,6 +10,9 @@ #include namespace onnxruntime { + +using ConvPadVector = ConvAttributes::ConvPadVector; + namespace rocm { class MiopenConvolutionDescriptor final { @@ -18,9 +21,9 @@ class MiopenConvolutionDescriptor final { ~MiopenConvolutionDescriptor(); Status Set(size_t rank, - gsl::span pads, - gsl::span strides, - gsl::span dilations, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, int groups, miopenConvolutionMode_t mode, miopenDataType_t data_type); @@ -198,7 +201,7 @@ class Conv : public RocmKernel { Status SliceOutUnwantedOutputSection(hipStream_t stream, const void* input_data, - const gsl::span& input_dims, + gsl::span input_dims, void* output_data, const gsl::span& output_dims, const gsl::span& starts, diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc index 475d26d2e306d..23e9faedb1e76 100644 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ 
b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc @@ -93,9 +93,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } s_.y_dims = gsl::make_span(y_dims); - if (w_dims_changed) { + if (w_dims_changed) ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); - } // Special case when there is a dim value of 0 in the shape. // Return only after we have cached the following for subsequent runs : diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 4f726017d8b14..820745b22f614 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -8,6 +8,9 @@ #include "core/providers/rocm/math/binary_elementwise_ops_impl.h" #include "core/providers/rocm/math/binary_elementwise_ops.h" #include "core/providers/rocm/math/unary_elementwise_ops_impl.h" +#ifdef ENABLE_TRAINING +#include "contrib_ops/cpu/aten_ops/aten_op.h" +#endif using namespace onnxruntime::common; namespace onnxruntime { @@ -100,8 +103,8 @@ namespace rocm { (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -// ROCM ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now. -#define REGISTER_KERNEL_TYPED_11(name, T) \ +// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet +#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -110,10 +113,10 @@ namespace rocm { kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 11, \ + 11, 11, \ T, \ kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ @@ -166,7 +169,6 @@ Status ReduceKernel::ReduceKernelShared( const auto rank = input_shape.NumDimensions(); auto hip_stream = stream ? static_cast(stream->GetHandle()) : nullptr; - // Block of fast matrix reduction. if (fast_reduction_) { int m{}, n{}; @@ -210,10 +212,8 @@ Status ReduceKernel::ReduceKernelShared( ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, MiopenTensor::GetDataType(), ReduceTensorIndices)); else ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, ReduceTensorIndices)); - const auto one = ReduceConsts::One; const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; MiopenTensor output_tensor; ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); @@ -444,17 +444,18 @@ template Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, gsl::span axes, - bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, Stream* ort_stream, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, const TensorShape* input_shape_override) { typedef typename ToHipType::MappedType HipT; const TensorShape& input_shape = input_shape_override ? *input_shape_override : input.Shape(); + hipStream_t stream = ort_stream ? 
static_cast(ort_stream->GetHandle()) : nullptr; int64_t input_count = prepare_reduce_metadata.input_count; int64_t output_count = prepare_reduce_metadata.output_count; auto& output_dims = prepare_reduce_metadata.output_dims; auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; - hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // special case when there is a dim value of 0 in the shape. if (input_count == 0) { assert(output.Shape().Size() == 0); @@ -540,7 +541,6 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, const auto one = ReduceConsts::One; const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; MiopenTensor output_tensor; ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); @@ -588,11 +588,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, input_tensor, output_tensor, &indices_bytes_max)); auto indices_rocm_max = indices_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitRocmNotificationOnDevice); + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, indices_rocm_max.get(), indices_bytes_max, workspace_rocm.get(), workspace_bytes, &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } // Exp(X-ReduceMax) @@ -652,11 +653,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if (input_count == output_count) { HIP_RETURN_IF_ERROR(hipMemcpyAsync(reinterpret_cast(output.MutableData()), input_data, input_count * sizeof(T), hipMemcpyDeviceToDevice, stream)); } else { + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, &one, input_tensor, input_data, - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } } else { // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case @@ -675,11 +677,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, Impl_Cast(stream, temp_output.get(), reinterpret_cast(output.MutableData()), output_count); } else { + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } } } @@ -743,18 +746,29 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR // empty axes and no-op if (axes.empty() && noop_with_empty_axes_) { auto* Y = ctx->Output(0, X->Shape()); - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream(ctx))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), + 
@@ -743,18 +746,29 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, miopenR
   // empty axes and no-op
   if (axes.empty() && noop_with_empty_axes_) {
     auto* Y = ctx->Output(0, X->Shape());
-    HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData<T>(), X->Data<T>(), X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream(ctx)));
+    HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData<T>(), X->Data<T>(), X->SizeInBytes(),
+                                       hipMemcpyDeviceToDevice, Stream(ctx)));
     return Status::OK();
   }
 
+#ifdef ENABLE_TRAINING
+  // Use ATen for ReduceSum if possible.
+  const TensorShape& input_shape = X->Shape();
+  if (contrib::IsATenOperatorExecutorInitialized() && miopen_reduce_op == MIOPEN_REDUCE_TENSOR_ADD && !calculate_log_ &&
+      !calculate_sqt_ && !log_sum_exp_ && input_shape.Size() > 0) {
+    if (axes.empty()) {
+      axes.resize(input_shape.NumDimensions());
+      std::iota(axes.begin(), axes.end(), 0);
+    }
+    ORT_RETURN_IF_ERROR(contrib::ExecuteReduceSumATen(ctx, axes, keepdims_));
+    return Status::OK();
+  }
+#endif
+
   PrepareReduceMetadata prepare_reduce_metadata;
-  ORT_RETURN_IF_ERROR(PrepareForReduce(X,
-                                       keepdims_,
-                                       axes,
-                                       prepare_reduce_metadata));
+  ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata));
 
   Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);
   const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute();
-  return ReduceComputeCore<T, ReduceTensorIndices>(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, miopen_reduce_op, axes, calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream());
+  return ReduceComputeCore<T, ReduceTensorIndices>(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y,
+                                                   miopen_reduce_op, axes, calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction,
+                                                   ctx->GetComputeStream());
 }
@@ -837,7 +851,6 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, miopenR
     MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(GetMiopenHandle(ctx), reduce_desc, indices_rocm.get(), indices_bytes, \
                                               workspace_rocm.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \
                                               &zero, output_tensor, temp_Y.get()));                                  \
-                                                                                                                     \
     Impl_Cast(Stream(ctx), temp_Y.get(), reinterpret_cast(Y->MutableData()), output_count);                          \
                                                                                                                      \
     return Status::OK();                                                                                             \
   }
@@ -909,13 +922,13 @@ template std::unique_ptr<Tensor> ReduceCompute
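Under ENABLE_TRAINING, ComputeImpl now short-circuits ReduceSum to the ATen executor when one is registered. The only non-obvious step is the axes normalization: ONNX ReduceSum treats an empty axes list as "reduce over every dimension", so the kernel materializes [0, 1, ..., rank-1] with std::iota before delegating. A minimal sketch of just that normalization, in plain C++ with no ORT types:

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

std::vector<int64_t> NormalizeReduceAxes(std::vector<int64_t> axes, size_t rank) {
  if (axes.empty()) {
    axes.resize(rank);
    std::iota(axes.begin(), axes.end(), 0);  // reduce over all dimensions
  }
  return axes;
}

int main() {
  auto axes = NormalizeReduceAxes({}, 4);
  for (int64_t a : axes) std::printf("%lld ", static_cast<long long>(a));  // 0 1 2 3
  std::printf("\n");
}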
+#include "core/common/inlined_containers.h" #include "core/providers/shared_library/provider_api.h" #include "core/platform/env_var_utils.h" #include "core/providers/rocm/rocm_execution_provider.h" @@ -9,7 +10,6 @@ #include "core/providers/rocm/rocm_fwd.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/rocm_profiler.h" -#include "core/providers/rocm/rocm_stream_handle.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/rocm/rocm_contrib_kernels.h" @@ -23,6 +23,8 @@ #include "core/providers/rocm/triton_kernel.h" #endif +#include "core/providers/rocm/rocm_stream_handle.h" + using namespace onnxruntime::common; namespace onnxruntime { @@ -38,42 +40,64 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); - return gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()); - } else if (X_type->IsSparseTensorType()) { - const auto* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); - SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); - return X->Copy(Info().GetDataTransferManager(), *Y); - } else if (X_type->IsTensorSequenceType()) { - const TensorSeq* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); - TensorSeq* Y = ctx->Output(0); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); - auto X_dtype = X->DataType(); - Y->SetType(X_dtype); - AllocatorPtr alloc; - auto status = ctx->GetTempSpaceAllocator(&alloc); - if (!status.IsOK()) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Memcpy rocm: unable to get an allocator."); - } - auto X_size = X->Size(); - Y->Reserve(X_size); - for (size_t i = 0; i < X_size; ++i) { - const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); - const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device); - Status retval = gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream()); - if (!retval.IsOK()) { - return retval; + // do we support async copy? + // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, + // so we don't need the check here. 
diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc
index 730f55608c725..484e59f4de7d8 100644
--- a/onnxruntime/core/providers/rocm/rocm_call.cc
+++ b/onnxruntime/core/providers/rocm/rocm_call.cc
@@ -39,11 +39,11 @@ const char* RocmErrString(rocblas_status e) {
     CASE_ENUM_TO_STR(rocblas_status_invalid_handle);
     CASE_ENUM_TO_STR(rocblas_status_not_implemented);
     CASE_ENUM_TO_STR(rocblas_status_invalid_pointer);
+    CASE_ENUM_TO_STR(rocblas_status_size_query_mismatch);
    CASE_ENUM_TO_STR(rocblas_status_invalid_size);
    CASE_ENUM_TO_STR(rocblas_status_memory_error);
    CASE_ENUM_TO_STR(rocblas_status_internal_error);
    CASE_ENUM_TO_STR(rocblas_status_perf_degraded);
-    CASE_ENUM_TO_STR(rocblas_status_size_query_mismatch);
    CASE_ENUM_TO_STR(rocblas_status_size_increased);
    CASE_ENUM_TO_STR(rocblas_status_size_unchanged);
    CASE_ENUM_TO_STR(rocblas_status_invalid_value);
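The hunk above only reorders cases, but the CASE_ENUM_TO_STR pattern itself is worth spelling out: a macro stringizes the enumerator, so the returned text can never drift from the enumerator's spelling. A sketch with a hypothetical Status enum (not the rocblas one), using an enum-class variant of the same idea:

#include <cstdio>

enum class Status { ok, invalid_handle, invalid_size };

// Stringize the enumerator: case Status::x -> return "x".
#define CASE_ENUM_TO_STR(x) \
  case Status::x:           \
    return #x

const char* StatusString(Status s) {
  switch (s) {
    CASE_ENUM_TO_STR(ok);
    CASE_ENUM_TO_STR(invalid_handle);
    CASE_ENUM_TO_STR(invalid_size);
    default:
      return "(unknown)";
  }
}

int main() { std::printf("%s\n", StatusString(Status::invalid_size)); }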
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index c9975d0bc76c0..3c106313a89cd 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/common/inlined_containers.h"
 #include "core/providers/shared_library/provider_api.h"
 #include "core/platform/env_var_utils.h"
 #include "core/providers/rocm/rocm_execution_provider.h"
@@ -9,7 +10,6 @@
 #include "core/providers/rocm/rocm_fwd.h"
 #include "core/providers/rocm/gpu_data_transfer.h"
 #include "core/providers/rocm/rocm_profiler.h"
-#include "core/providers/rocm/rocm_stream_handle.h"
 
 #ifndef DISABLE_CONTRIB_OPS
 #include "contrib_ops/rocm/rocm_contrib_kernels.h"
@@ -23,6 +23,8 @@
 #include "core/providers/rocm/triton_kernel.h"
 #endif
 
+#include "core/providers/rocm/rocm_stream_handle.h"
+
 using namespace onnxruntime::common;
 
 namespace onnxruntime {
@@ -38,42 +40,64 @@ class Memcpy final : public OpKernel {
       ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
       Tensor* Y = ctx->Output(0, X->Shape());
       ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor.");
-      const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
-      return gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream());
-    } else if (X_type->IsSparseTensorType()) {
-      const auto* X = ctx->Input<SparseTensor>(0);
-      ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
-      SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape());
-      ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor.");
-      return X->Copy(Info().GetDataTransferManager(), *Y);
-    } else if (X_type->IsTensorSequenceType()) {
-      const TensorSeq* X = ctx->Input<TensorSeq>(0);
-      ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
-      TensorSeq* Y = ctx->Output<TensorSeq>(0);
-      ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
-      auto X_dtype = X->DataType();
-      Y->SetType(X_dtype);
-      AllocatorPtr alloc;
-      auto status = ctx->GetTempSpaceAllocator(&alloc);
-      if (!status.IsOK()) {
-        return Status(common::ONNXRUNTIME, common::FAIL,
-                      "Memcpy rocm: unable to get an allocator.");
-      }
-      auto X_size = X->Size();
-      Y->Reserve(X_size);
-      for (size_t i = 0; i < X_size; ++i) {
-        const Tensor& source_tensor = X->Get(i);
-        std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc);
-        const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device);
-        Status retval = gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream());
-        if (!retval.IsOK()) {
-          return retval;
+      // do we support async copy?
+      // hipMemcpyAsync handles both pinned and non-pinned memory,
+      // so we don't need the check here.
+      auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);
+      ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()));
+      return Status::OK();
+    } else {
+      if (X_type->IsSparseTensorType()) {
+        // TODO: support async copy for sparse tensor
+        // sync the stream first, since it is a sync memory copy
+        HIP_CALL_THROW(hipStreamSynchronize(static_cast<hipStream_t>(ctx->GetComputeStream()->GetHandle())));
+        const auto* X = ctx->Input<SparseTensor>(0);
+        ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr.");
+        SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape());
+        ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor.");
+        return X->Copy(Info().GetDataTransferManager(), *Y);
+      } else if (X_type->IsTensorSequenceType()) {
+        const TensorSeq* X = ctx->Input<TensorSeq>(0);
+        ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr.");
+        TensorSeq* Y = ctx->Output<TensorSeq>(0);
+        ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence.");
+        auto X_dtype = X->DataType();
+        Y->SetType(X_dtype);
+        AllocatorPtr alloc;
+
+        // If we are copying contents to ROCM, the allocator to use
+        // to allocate the buffers of the new tensors in the sequence
+        // can be temp space allocator associated with the ROCM EP
+        if (Node().OpType() == "MemcpyFromHost") {
+          auto status = ctx->GetTempSpaceAllocator(&alloc);
+          if (!status.IsOK()) {
+            return Status(common::ONNXRUNTIME, common::FAIL,
+                          "Memcpy rocm: unable to get an allocator.");
+          }
+        } else {
+          // If we are copying contents to CPU (op type is "MemcpyToHost"),
+          // the allocator to use to allocate the buffers of the new tensors
+          // in the sequence will be the allocator from the CPU EP
+          auto status = ctx->GetTempSpaceCPUAllocator(&alloc);
+          if (!status.IsOK()) {
+            return Status(common::ONNXRUNTIME, common::FAIL,
+                          "Memcpy rocm: unable to get the CPU allocator.");
+          }
+        }
+        auto X_size = X->Size();
+        Y->Reserve(X_size);
+        for (size_t i = 0; i < X_size; ++i) {
+          const Tensor& source_tensor = X->Get(i);
+          std::unique_ptr<Tensor> target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc);
+          auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device,
+                                                                                    target_tensor->Location().device);
+          ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream()));
+          Y->Add(std::move(*target_tensor));
         }
-        Y->Add(std::move(*target_tensor));
+        return Status::OK();
       }
-      return Status::OK();
+      return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
     }
-    return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type.");
   }
 };
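The one piece of logic the reworked Memcpy kernel adds for tensor sequences is allocator selection by copy direction: buffers for the copied tensors are allocated on the side the data is moving to, derived from the node's op type. A toy sketch of that rule, with a hypothetical Device enum in place of ORT's allocator machinery:

#include <cstdio>
#include <string>

enum class Device { CPU, GPU };

// "MemcpyFromHost" moves host data onto the GPU; "MemcpyToHost" is the reverse.
Device TargetDevice(const std::string& op_type) {
  return op_type == "MemcpyFromHost" ? Device::GPU : Device::CPU;
}

int main() {
  std::printf("MemcpyFromHost allocates on %s\n",
              TargetDevice("MemcpyFromHost") == Device::GPU ? "GPU" : "CPU");
  std::printf("MemcpyToHost   allocates on %s\n",
              TargetDevice("MemcpyToHost") == Device::GPU ? "GPU" : "CPU");
}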
@@ -100,18 +124,23 @@ ONNX_OPERATOR_KERNEL_EX(
 
 }  // namespace rocm
 
-AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t gpu_mem_limit, ArenaExtendStrategy arena_extend_strategy,
-                                                        ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) {
+AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id,
+                                                        size_t gpu_mem_limit,
+                                                        ArenaExtendStrategy arena_extend_strategy,
+                                                        ROCMExecutionProviderExternalAllocatorInfo external_allocator_info,
+                                                        const OrtArenaCfg* default_memory_arena_cfg) {
   if (external_allocator_info.UseExternalAllocator()) {
     AllocatorCreationInfo default_memory_info(
         [external_allocator_info](OrtDevice::DeviceId id) {
-          return std::make_unique<ROCMExternalAllocator>(id, HIP, external_allocator_info.alloc, external_allocator_info.free, external_allocator_info.empty_cache);
+          return std::make_unique<ROCMExternalAllocator>(id, HIP,
+                                                         external_allocator_info.alloc,
+                                                         external_allocator_info.free,
+                                                         external_allocator_info.empty_cache);
         },
         device_id,
         false);
     return CreateAllocator(default_memory_info);
-  } else {
     AllocatorCreationInfo default_memory_info(
         [](OrtDevice::DeviceId id) {
@@ -120,12 +149,7 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi
         device_id,
         true,
         {default_memory_arena_cfg ? *default_memory_arena_cfg
-                                  : OrtArenaCfg(gpu_mem_limit,
-                                                static_cast<int>(arena_extend_strategy),
-                                                -1,
-                                                -1,
-                                                -1,
-                                                -1)},
+                                  : OrtArenaCfg(gpu_mem_limit, static_cast<int>(arena_extend_strategy), -1, -1, -1, -1L)},
         // make it stream aware
         true,
         // enable cross stream sharing?
@@ -149,20 +173,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de
 }
 
 ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
-  // dtor shouldn't throw. if something went wrong earlier (e.g. out of ROCM memory) the handles
-  // here may be bad, and the destroy calls can throw.
-  // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-noexcept
-  try {
-    ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_));
-  } catch (const std::exception& ex) {
-    LOGS_DEFAULT(ERROR) << "rocblas_destroy_handle threw:" << ex.what();
-  }
-
-  try {
-    MIOPEN_CALL_THROW(miopenDestroy(miopen_handle_));
-  } catch (const std::exception& ex) {
-    LOGS_DEFAULT(ERROR) << "miopenDestroy threw:" << ex.what();
-  }
+  ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(rocblas_handle_)));
+  ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_)));
 }
 
 void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) {
@@ -235,7 +247,7 @@ ROCMExecutionProvider::~ROCMExecutionProvider() {
   }
 
   if (!external_stream_ && stream_) {
-    HIP_CALL_THROW(hipStreamDestroy(stream_));
+    ORT_IGNORE_RETURN_VALUE(HIP_CALL(hipStreamDestroy(stream_)));
   }
 }
 
@@ -315,7 +327,7 @@ Status ROCMExecutionProvider::OnRunStart() {
 Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
   if (sync_stream) {
-    HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
+    HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream_)));
  }
 
  // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
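The destructor hunks above replace try/catch around throwing wrappers with status-returning calls whose results are deliberately discarded: a destructor must not throw, and a teardown failure has nowhere useful to propagate. A sketch of the same policy with a hypothetical ReleaseHandle wrapper (not the rocBLAS/MIOpen ones):

#include <cstdio>

enum class CleanupStatus { ok, failed };

// A real wrapper would log on failure and report it via the return value.
CleanupStatus ReleaseHandle(int handle) {
  return handle >= 0 ? CleanupStatus::ok : CleanupStatus::failed;
}

struct PerThreadContext {
  int handle = 42;
  ~PerThreadContext() noexcept {
    // Discard the status explicitly (the ORT_IGNORE_RETURN_VALUE idea):
    // the call cannot throw, and the destructor stays noexcept.
    (void)ReleaseHandle(handle);
  }
};

int main() {
  { PerThreadContext ctx; }  // destructor runs, never throws
  std::printf("clean teardown\n");
}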
@@ -716,12 +728,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod);
 
 // opset 11
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ArgMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ArgMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ArgMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ArgMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten);
@@ -774,7 +786,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare);
-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 18, Scan);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 15, Scan);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterElements);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Slice);
@@ -827,12 +839,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round);
-
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear);
-
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 13, CumSum);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_float_int64_t, OneHot);
@@ -1087,6 +1097,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
 
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
@@ -1105,17 +1126,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
 
 // OpSet 14
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
@@ -1186,12 +1196,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, 18, Shape);
 
 // Opset 16
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, PRelu);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, PRelu);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LeakyRelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider,
kOnnxDomain, 16, double, LeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LeakyRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double_t, Where); @@ -1260,977 +1271,929 @@ KernelCreateInfo BuildKernelCreateInfo() { static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, 
- BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 10 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // opset 11 - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 12 - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 13 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, 
- BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 14 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // OpSet 15 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 16 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 17 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - // Opset 18 - BuildKernelCreateInfo, - - // Opset 19 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // opset 10 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // opset 11 + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // OpSet 12 + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // OpSet 13 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, 
+ BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // OpSet 14 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // OpSet 15 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // Opset 16 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // Opset 17 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // Opset 18 + BuildKernelCreateInfo, + + // Opset 19 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -2336,7 +2299,6 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // These are usually shape related computation subgraphs // Following logic can be extended for other EPs auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); - std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) @@ -2371,7 +2333,7 @@ OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) cons std::vector ROCMExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_memory_info( - [](OrtDevice::DeviceId device_id) { + [](OrtDevice::DeviceId) { return std::make_unique(HIP_PINNED); }, // TODO: should we use info_.device_id instead of DEFAULT_CPU_ALLOCATOR_DEVICE_ID? @@ -2383,7 +2345,8 @@ std::vector ROCMExecutionProvider::CreatePreferredAllocators() { return std::vector{ CreateRocmAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg), - CreateAllocator(pinned_memory_info)}; + CreateAllocator(pinned_memory_info), + }; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 3e86afb7d643c..c4945b9ac2481 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -36,11 +36,11 @@ class ROCMExecutionProvider : public IExecutionProvider { return nullptr; } - rocblas_handle PerThreadRocblasHandle() { + rocblas_handle PerThreadDefaultRocblasHandle() { return GetPerThreadContext().RocblasHandle(); } - miopenHandle_t PerThreadMiopenHandle() { + miopenHandle_t PerThreadDefaultMiopenHandle() { return GetPerThreadContext().MiopenHandle(); } @@ -60,7 +60,6 @@ class ROCMExecutionProvider : public IExecutionProvider { const hipDeviceProp_t& GetDeviceProp() const { return device_prop_; }; int GetMiopenConvExhaustiveSearch() const { return info_.miopen_conv_exhaustive_search; } bool DoCopyOnDefaultStream() const { return info_.do_copy_in_default_stream; } - bool GetMiopenConvUseMaxWorkspace() const { return info_.miopen_conv_use_max_workspace; } ProviderOptions GetProviderOptions() const override { @@ -68,15 +67,15 @@ class ROCMExecutionProvider : public IExecutionProvider { } static AllocatorPtr CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); + ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); ITuningContext* GetTuningContext() const override; std::unique_ptr GetProfiler() override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - std::vector CreatePreferredAllocators() override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const 
override; + std::vector CreatePreferredAllocators() override; private: ROCMExecutionProviderInfo info_; @@ -105,21 +104,30 @@ class ROCMExecutionProvider : public IExecutionProvider { template const T* GetConstOnes(size_t count, hipStream_t stream) { - if (std::is_same::value) { + constexpr bool is_float = std::is_same::value; + constexpr bool is_double = std::is_same::value; + constexpr bool is_half = std::is_same::value; + constexpr bool is_BFloat16 = std::is_same::value; + if (is_float) { if (!constant_ones_float_) { constant_ones_float_ = rocm::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float_->GetBuffer(stream, count)); - } else if (std::is_same::value) { + } else if (is_double) { if (!constant_ones_double_) { constant_ones_double_ = rocm::CreateConstantOnes(); } return reinterpret_cast(constant_ones_double_->GetBuffer(stream, count)); - } else if (std::is_same::value) { + } else if (is_half) { if (!constant_ones_half_) { constant_ones_half_ = rocm::CreateConstantOnes(); } return reinterpret_cast(constant_ones_half_->GetBuffer(stream, count)); + } else if (is_BFloat16) { + if (!constant_ones_bfloat16_) { + constant_ones_bfloat16_ = rocm::CreateConstantOnes(); + } + return reinterpret_cast(constant_ones_bfloat16_->GetBuffer(stream, count)); } else { return nullptr; } @@ -132,6 +140,7 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr> constant_ones_float_; std::unique_ptr> constant_ones_double_; std::unique_ptr> constant_ones_half_; + std::unique_ptr> constant_ones_bfloat16_; }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 91e3aaaa4280f..650635c153640 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -27,12 +27,10 @@ constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_dur } // namespace provider_option_names } // namespace rocm -namespace { const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"}, {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -} // namespace ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { ROCMExecutionProviderInfo info{}; @@ -81,7 +79,9 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToEnumReference( rocm::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) - .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvExhaustiveSearch, info.miopen_conv_exhaustive_search) + .AddAssignmentToReference( + rocm::provider_option_names::kMiopenConvExhaustiveSearch, + info.miopen_conv_exhaustive_search) .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) .AddValueParser( diff --git a/onnxruntime/core/providers/rocm/rocm_fwd.h b/onnxruntime/core/providers/rocm/rocm_fwd.h deleted file mode 100644 index b123446fa9be1..0000000000000 --- a/onnxruntime/core/providers/rocm/rocm_fwd.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -namespace rocm { -template -KernelCreateInfo BuildKernelCreateInfo(); -} -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 02f15fdad8b77..463c1cf0d2ea6 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -35,14 +35,12 @@ class RocmKernel : public OpKernel { // use this to precisely locate the node where ROCM failure comes from // if (hipSuccess != hipDeviceSynchronize()) // __debugbreak(); - if (s.IsOK()) { auto err = hipGetLastError(); if (err != hipSuccess) { - s = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP error ", hipGetErrorName(err), ":", hipGetErrorString(err)); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP error ", hipGetErrorName(err), ":", hipGetErrorString(err)); } } - return s; } @@ -64,18 +62,18 @@ class RocmKernel : public OpKernel { return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, true); } - template - inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t count_or_bytes) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeCPU), count_or_bytes); - } - inline void AddDeferredReleaseCPUPtr(void* p, onnxruntime::Stream* ort_stream) const { ORT_ENFORCE(ort_stream->GetDevice().Type() == OrtDevice::GPU); auto* rocm_ep_stream = static_cast(ort_stream); rocm_ep_stream->EnqueDeferredCPUBuffer(p); } + template + inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t count_or_bytes) const { + if (count_or_bytes == 0) return nullptr; + return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeCPU), count_or_bytes); + } + const hipDeviceProp_t& GetDeviceProp() const { return provider_->GetDeviceProp(); } inline hipStream_t Stream(OpKernelContext* ctx) const { @@ -83,6 +81,22 @@ class RocmKernel : public OpKernel { return stream ? 
static_cast(stream->GetHandle()) : nullptr; } + inline miopenHandle_t GetMiopenHandle(OpKernelContext* ctx) const { + return GetMiopenHandle(static_cast(ctx->GetComputeStream())); + } + + static inline miopenHandle_t GetMiopenHandle(onnxruntime::RocmStream* stream) { + return stream->miopen_handle_; + } + + inline rocblas_handle GetRocblasHandle(OpKernelContext* ctx) const { + return GetRocblasHandle(static_cast(ctx->GetComputeStream())); + } + + static inline rocblas_handle GetRocblasHandle(onnxruntime::RocmStream* stream) { + return stream->rocblas_handle_; + } + tunable::RocmTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } @@ -106,7 +120,7 @@ class RocmKernel : public OpKernel { } } - RocmAsyncBuffer(const RocmKernel* op_kernel, gsl::span vec) : RocmAsyncBuffer(op_kernel, vec.size()) { + RocmAsyncBuffer(const RocmKernel* op_kernel, gsl::span vec) : RocmAsyncBuffer(op_kernel, vec.size()) { memcpy(CpuPtr(), vec.data(), vec.size() * sizeof(T)); } @@ -151,28 +165,12 @@ class RocmKernel : public OpKernel { const RocmKernel* op_kernel_; }; - inline rocblas_handle RocblasHandle() const { - return provider_->PerThreadRocblasHandle(); - } - - inline miopenHandle_t MiopenHandle() const { - return provider_->PerThreadMiopenHandle(); + inline rocblas_handle DefaultRocblasHandle() const { + return provider_->PerThreadDefaultRocblasHandle(); } - static inline rocblas_handle GetRocblasHandle(onnxruntime::RocmStream* stream) { - return stream->rocblas_handle_; - } - - inline rocblas_handle GetRocblasHandle(OpKernelContext* ctx) const { - return GetRocblasHandle(static_cast(ctx->GetComputeStream())); - } - - static inline miopenHandle_t GetMiopenHandle(onnxruntime::RocmStream* stream) { - return stream->miopen_handle_; - } - - inline miopenHandle_t GetMiopenHandle(OpKernelContext* ctx) const { - return GetMiopenHandle(static_cast(ctx->GetComputeStream())); + inline miopenHandle_t DefaultMiopenHandle() const { + return provider_->PerThreadDefaultMiopenHandle(); } protected: diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index e55b2edbad685..f59e609b8f355 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -3,15 +3,13 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/rocm/rocm_provider_factory.h" - -#include +#include "core/providers/rocm/rocm_provider_factory_creator.h" #include "core/common/gsl.h" #include "core/providers/rocm/rocm_execution_provider.h" #include "core/providers/rocm/rocm_execution_provider_info.h" #include "core/providers/rocm/rocm_allocator.h" -#include "core/providers/rocm/rocm_provider_factory_creator.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/math/unary_elementwise_ops_impl.h" @@ -47,7 +45,7 @@ std::unique_ptr ROCMProviderFactory::CreateProvider() { return std::make_unique(info_); } -struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { +struct ProviderInfo_ROCM_Impl final : ProviderInfo_ROCM { OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) override { int num_devices; auto hip_err = ::hipGetDeviceCount(&num_devices); @@ -128,9 +126,26 @@ struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { } // Used by slice_concatenate_test.cc and onnxruntime_pybind_state.cc - void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { HIP_CALL_THROW(hipMemcpy(dst, src, count, 
hipMemcpyHostToDevice)); } + + void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { + // hipMemcpy() operates on the default stream + HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); + + // To ensure that the copy has completed, invoke a stream sync for the default stream. + // https://docs.nvidia.com/rocm/rocm-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync + // For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. + // The function will return once the pageable buffer has been copied to the staging memory for DMA transfer + // to device memory, but the DMA to final destination may not have completed. + + HIP_CALL_THROW(hipStreamSynchronize(0)); + } + // Used by onnxruntime_pybind_state.cc - void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } + void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { + // https://docs.nvidia.com/rocm/rocm-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync + // For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. + HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); + } int hipGetDeviceCount() override { int num_devices = 0; @@ -152,10 +167,9 @@ struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { return std::make_shared(info); } - std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) override { + std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { return ROCMExecutionProvider::CreateRocmAllocator(device_id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); } - } g_info; struct ROCM_Provider : Provider { @@ -169,8 +183,8 @@ struct ROCM_Provider : Provider { info.gpu_mem_limit = params->gpu_mem_limit; info.arena_extend_strategy = static_cast(params->arena_extend_strategy); info.miopen_conv_exhaustive_search = params->miopen_conv_exhaustive_search; - info.do_copy_in_default_stream = params->do_copy_in_default_stream; - info.has_user_compute_stream = params->has_user_compute_stream; + info.do_copy_in_default_stream = params->do_copy_in_default_stream != 0; + info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; info.tunable_op.enable = params->tunable_op_enable; @@ -180,21 +194,32 @@ struct ROCM_Provider : Provider { return std::make_shared(info); } + /** + * This function will be called by the C API UpdateROCMProviderOptions(). + * + * What this function does is equivalent to resetting the OrtROCMProviderOptions instance with + * default ROCMExecutionProviderInf instance first and then set up the provided provider options. + * See ROCMExecutionProviderInfo::FromProviderOptions() for more details. 
+ */ void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto info = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); + auto internal_options = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); auto& rocm_options = *reinterpret_cast(provider_options); - rocm_options.device_id = info.device_id; - rocm_options.gpu_mem_limit = info.gpu_mem_limit; - rocm_options.arena_extend_strategy = static_cast(info.arena_extend_strategy); - rocm_options.miopen_conv_exhaustive_search = info.miopen_conv_exhaustive_search; - rocm_options.do_copy_in_default_stream = info.do_copy_in_default_stream; - rocm_options.has_user_compute_stream = info.has_user_compute_stream; - rocm_options.user_compute_stream = info.user_compute_stream; - rocm_options.default_memory_arena_cfg = info.default_memory_arena_cfg; - rocm_options.tunable_op_enable = info.tunable_op.enable; - rocm_options.tunable_op_tuning_enable = info.tunable_op.tuning_enable; - rocm_options.tunable_op_max_tuning_duration_ms = info.tunable_op.max_tuning_duration_ms; + rocm_options.device_id = internal_options.device_id; + rocm_options.gpu_mem_limit = internal_options.gpu_mem_limit; + rocm_options.arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); + rocm_options.miopen_conv_exhaustive_search = internal_options.miopen_conv_exhaustive_search; + rocm_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; + rocm_options.has_user_compute_stream = internal_options.has_user_compute_stream; + // The 'has_user_compute_stream' of the OrtROCMProviderOptions instance can be set byC API UpdateROCMProviderOptionsWithValue() as well. + // We only set the 'has_user_compute_stream' of the OrtROCMProviderOptions instance if it is provided in options + if (options.find("has_user_compute_stream") != options.end()) { + rocm_options.user_compute_stream = internal_options.user_compute_stream; + } + rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; + rocm_options.tunable_op_enable = internal_options.tunable_op.enable; + rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; + rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; } ProviderOptions GetProviderOptions(const void* provider_options) override { @@ -219,4 +244,6 @@ extern "C" { ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } + } + diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.h b/onnxruntime/core/providers/rocm/rocm_provider_factory.h index 8cd7bd357330f..80b887af4eb75 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.h +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.h @@ -3,6 +3,7 @@ #include "onnxruntime_c_api.h" #include "core/framework/provider_options.h" +#include "core/common/common.h" namespace onnxruntime { class IAllocator; @@ -43,7 +44,16 @@ struct ProviderInfo_ROCM { #endif virtual std::shared_ptr CreateExecutionProviderFactory(const onnxruntime::ROCMExecutionProviderInfo& info) = 0; - virtual std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, 
onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + + // This function is the entry point to ROCM EP's UT cases. + // All tests ared only called from onnxruntime_test_all. + virtual void TestAll() { + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is only implements in test code path."); + } + + protected: + ~ProviderInfo_ROCM() = default; // Can only be destroyed through a subclass instance }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc index 0d9877e6b18e6..670aae91ca710 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc @@ -1,7 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/rocm/rocm_resource.h" #include "core/providers/rocm/rocm_stream_handle.h" #include "core/providers/rocm/rocm_common.h" // #include "core/common/spin_pause.h" -#include "core/providers/rocm/rocm_resource.h" namespace onnxruntime { @@ -82,15 +84,29 @@ void RocmStream::EnqueDeferredCPUBuffer(void* cpu_buffer) { deferred_cpu_buffers_.push_back(cpu_buffer); } -struct CpuBuffersInfo { // TODO: should be moved to base class +struct CpuBuffersInfo { + // This struct stores the information needed + // to release CPU buffers allocated for GPU kernels. + // It's used to enqueue their release after + // associated GPU kernels in a ROCM stream. + + // This is a CPU allocator in ROCM EP. + // It must be the one used to allocate the + // following pointers. AllocatorPtr allocator; + // buffers[i] is the i-th pointer added by + // AddDeferredReleaseCPUPtr for a specific + // ROCM stream. For example, this fields + // should contain all values in + // deferred_release_buffer_pool_[my_stream] + // when release my_stream's buffers. std::unique_ptr buffers; // CPU buffer buffers[i]. // Number of buffer points in "buffers". size_t n_buffers; }; -static void ReleaseCpuBufferCallback(hipStream_t /*stream*/, hipError_t /*status*/, void* raw_info) { // TODO: should be moved to base class +static void ReleaseCpuBufferCallback(void* raw_info) { std::unique_ptr info = std::make_unique(); info.reset(reinterpret_cast(raw_info)); for (size_t i = 0; i < info->n_buffers; ++i) { @@ -111,14 +127,7 @@ Status RocmStream::CleanUpOnRunEnd() { cpu_buffers_info->buffers[i] = deferred_cpu_buffers_.at(i); } cpu_buffers_info->n_buffers = deferred_cpu_buffers_.size(); - // TODO(wechi): CUDA deprecates cudaStreamAddCallback and - // uses another API, cudaLaunchHostFunc(which can be - // captured in CUDA graph). Once AMD adds similar feature, - // we should replace the following line with - // hipLaunchHostFunc(stream, ReleaseCpuBufferCallback, cpu_buffers_info); - - // Release memory asynchronously to avoid blocking the compute stream. 
- HIP_RETURN_IF_ERROR(hipStreamAddCallback(static_cast(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release(), 0)); + HIP_RETURN_IF_ERROR(hipLaunchHostFunc(static_cast(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release())); } else { HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(GetHandle()))); for (auto* buffer : deferred_cpu_buffers_) { @@ -130,10 +139,10 @@ Status RocmStream::CleanUpOnRunEnd() { return Status::OK(); } -void* RocmStream::GetResource(int version, int type) const { +void* RocmStream::GetResource(int version, int id) const { ORT_ENFORCE(version <= ORT_ROCM_RESOUCE_VERSION, "resource version unsupported!"); void* resource{}; - switch (type) { + switch (id) { case RocmResource::hip_stream_t: return reinterpret_cast(GetHandle()); break; @@ -149,6 +158,7 @@ void* RocmStream::GetResource(int version, int type) const { return resource; } +// CPU Stream command handles void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { static_cast(¬ification)->wait_on_device(stream); } diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h index 865cff0abf85f..1f3e5b75548e7 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "core/providers/rocm/rocm_pch.h" // #include "core/providers/cuda/shared_inc/cuda_utils.h" @@ -17,14 +20,12 @@ struct RocmStream : Stream { ~RocmStream(); - std::unique_ptr CreateNotification(size_t num_consumers) override; + std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; void Flush() override; Status CleanUpOnRunEnd() override; - void* GetResource(int version, int id) const override; - void EnqueDeferredCPUBuffer(void* cpu_buffer); bool own_stream_{true}; @@ -33,6 +34,8 @@ struct RocmStream : Stream { rocblas_handle rocblas_handle_{}; + void* GetResource(int version, int id) const override; + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; diff --git a/onnxruntime/core/providers/rocm/rocm_utils.cu b/onnxruntime/core/providers/rocm/rocm_utils.cu index cbf410e78a4a9..b817e025cedf4 100644 --- a/onnxruntime/core/providers/rocm/rocm_utils.cu +++ b/onnxruntime/core/providers/rocm/rocm_utils.cu @@ -30,13 +30,14 @@ template void Fill(hipStream_t stream, T* output, T value, int64_t count) { int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); HIP_LONG N = static_cast(count); - _Fill<<>>(output, value, N); + _Fill + <<>>(output, value, N); } template class ConstantBufferImpl : public IConstantBuffer { public: - ConstantBufferImpl(T val) : buffer_(nullptr), count_(0), val_(val) {} - + ConstantBufferImpl(T val) : buffer_(nullptr), count_(0), val_(val) { + } ~ConstantBufferImpl() { if (buffer_) HIP_CALL_THROW(hipFree(buffer_)); @@ -70,6 +71,7 @@ std::unique_ptr> CreateConstantOnes() { template std::unique_ptr> CreateConstantOnes(); template std::unique_ptr> CreateConstantOnes(); template std::unique_ptr> CreateConstantOnes(); +template std::unique_ptr> CreateConstantOnes(); #define SPECIALIZED_FILL(T) \ template void Fill(hipStream_t stream, T * output, T value, int64_t count); @@ -81,6 +83,7 @@ SPECIALIZED_FILL(int64_t) SPECIALIZED_FILL(float) SPECIALIZED_FILL(double) SPECIALIZED_FILL(__half) +SPECIALIZED_FILL(BFloat16) } // namespace rocm 
} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h b/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h deleted file mode 100644 index 83ca0a443c4fa..0000000000000 --- a/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h +++ /dev/null @@ -1,90 +0,0 @@ -// -// Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved -// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. -// - -#pragma once - -#include -#include -#include -#include -#include "core/common/common.h" - -namespace onnxruntime { -namespace rocm { - -// DivMod is a helper class for integer division and modulo operation. -// There is a fast version for int type and a slow version for other type. -template -struct DivMod { - DivMod(T d = 1) { - d_ = d == 0 ? 1 : d; - ORT_ENFORCE(d_ >= 1 && d_ <= std::numeric_limits::max()); - } - - __host__ __device__ inline T div(T n) const { - return n / d_; - } - - __host__ __device__ inline T mod(T n) const { - return n % d_; - } - - __host__ __device__ inline void divmod(T n, T& q, T& r) const { - q = div(n); - r = n - q * d_; - } - - T d_; // divisor -}; - -// The code below is based on section 4 Unsigned division of paper https://gmplib.org/~tege/divcnst-pldi94.pdf -// In current ORT, fast_divmod is used for calculating the position of a element in tensor, -// so unsigned integer division from the paper is good enough for ORT. The advantage is that div is very simple, -// then GPU compiler can do loop unroll easilly when divmod is called in a loop. -template <> -struct DivMod { - DivMod(int d = 1) { - d_ = d == 0 ? 1 : d; - ORT_ENFORCE(d_ >= 1 && d_ <= static_cast(std::numeric_limits::max())); - - for (l_ = 0; l_ < 32; l_++) - if ((1U << l_) >= d_) break; - - uint64_t one = 1; - uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1; - M_ = static_cast(m); - // according to paper, the value of m' should fit in a unsigned integer. - ORT_ENFORCE(M_ > 0 && M_ == m); - } - - __host__ __device__ inline int div(int n) const { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) - uint32_t t = __umulhi(M_, n); - return (t + n) >> l_; -#else - // Using uint64_t for t, then t + n won't overflow. - uint64_t t = ((uint64_t)M_ * n) >> 32; - return static_cast((t + n) >> l_); -#endif - } - - __host__ __device__ inline int mod(int n) const { - return n - div(n) * d_; - } - - __host__ __device__ inline void divmod(int n, int& q, int& r) const { - q = div(n); - r = n - q * d_; - } - - uint32_t d_; // divisor - uint32_t M_; // m' in the paper. - uint32_t l_; // l_ = ceil(log2(d_)) -}; - -using fast_divmod = DivMod; // Keep the old name for backward compatibility. 
- -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h index d6623ef63f0fd..b6b40666b8bd0 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h @@ -17,16 +17,20 @@ std::conditional_t RocmCall( #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) + #define HIPSPARSE_CALL(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define MIOPEN_CALL(expr) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) #define MIOPEN_CALL2(expr, m) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) + #define HIPFFT_CALL(expr) (RocmCall((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) + #define HIPSPARSE_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) + #define MIOPEN_CALL_THROW(expr) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) #define MIOPEN_CALL_THROW2(expr, m) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) #define HIPFFT_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index 15e2449cd230f..fff103e5fa339 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -197,9 +197,8 @@ TEST(BiasGeluTest, BFloat16) { } #endif +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(MathOpTest, ComplexMul) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { -0.5f, 0.6f}; @@ -219,13 +218,15 @@ TEST(MathOpTest, ComplexMul) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMulConj) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { -0.5f, 0.6f}; @@ -245,13 +246,15 @@ TEST(MathOpTest, ComplexMulConj) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMul_fp16) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { MLFloat16(-0.5f), MLFloat16(0.6f)}; 
@@ -271,13 +274,15 @@ TEST(MathOpTest, ComplexMul_fp16) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMulConj_fp16) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { MLFloat16(-0.5f), MLFloat16(0.6f)}; @@ -297,9 +302,14 @@ TEST(MathOpTest, ComplexMulConj_fp16) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc index eaadb95c8a0c0..7092f41da3443 100644 --- a/onnxruntime/test/contrib_ops/fft_op_test.cc +++ b/onnxruntime/test/contrib_ops/fft_op_test.cc @@ -8,7 +8,14 @@ namespace onnxruntime { namespace test { TEST(ContribOpTest, Rfft) { - if (DefaultCudaExecutionProvider() == nullptr) return; + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#else + return; +#endif OpTester test("Rfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast(1)); @@ -17,13 +24,18 @@ TEST(ContribOpTest, Rfft) { // Target values conputed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward") test.AddInput("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); test.AddOutput("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(ContribOpTest, Irfft) { - if (DefaultCudaExecutionProvider() == nullptr) return; + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#else + return; +#endif OpTester test("Irfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast(1)); @@ -31,8 +43,6 @@ TEST(ContribOpTest, Irfft) { test.AddAttribute("normalized", static_cast(0)); test.AddInput("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); test.AddOutput("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } // namespace test diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index f5259c1391f38..cf643fdf73753 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -50,11 +50,16 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) { const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"}; const char* const output_names[] = {"sequences"}; +#ifdef USE_CUDA constexpr int min_cuda_architecture = 530; - if (HasCudaEnvironment(min_cuda_architecture)) { + if (!HasCudaEnvironment(min_cuda_architecture)) return; +#endif + { Ort::SessionOptions session_options; #ifdef USE_CUDA Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); +#elif USE_ROCM + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); #endif // The following model was obtained by padding the vocabulary size in testdata/transformers/tiny_gpt2_beamsearch_fp16.onnx @@ -117,11 +122,16 @@ TEST(GreedySearchTest, GptGreedySearchFp32) { const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"}; const char* const output_names[] = {"sequences"}; +#ifdef USE_CUDA constexpr int min_cuda_architecture = 530; - if (HasCudaEnvironment(min_cuda_architecture)) { + if (!HasCudaEnvironment(min_cuda_architecture)) return; +#endif + { Ort::SessionOptions session_options; #ifdef USE_CUDA Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); +#elif USE_ROCM + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); #endif Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_greedysearch_with_init_decoder.onnx"), session_options); diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index e0293128045a5..6f492317524be 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -150,6 +150,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): # CUFFT -> HIPFFT s = s.replace("CUFFT", "HIPFFT") + s = s.replace("cufftXtMakePlanMany", "hipfftXtMakePlanMany") + s = s.replace("cufftXtExec", "hipfftXtExec") # Undo where above hipify steps went too far. s = s.replace("id, ROCM", "id, CUDA") # cuda_execution_provider.cc @@ -169,6 +171,24 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("#include ", "#include ") s = s.replace("#include ", "#include ") s = s.replace("#include ", "#include ") + s = s.replace("#include ", "#include ") + s = s.replace('#include "hipfft.h"', "#include ") + s = s.replace('#include "hipfftXt.h"', "#include ") + + # Fix onnxruntime/contrib_ops/rocm/transformers. They include cpu headers which use "cuda" in their names. + s = s.replace("rocm_device_prop_", "cuda_device_prop_") + s = s.replace("rocm_device_arch_", "cuda_device_arch_") + + # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names + # And we do this last, undoing or fixing hipify mistakes. 
+ if "fft" in src_file_path: + s = s.replace("rocblas_datatype", "hipDataType") + s = s.replace("hipDataType_f32_c", "HIP_C_32F") + s = s.replace("hipDataType_f32_r", "HIP_R_32F") + s = s.replace("hipDataType_f64_c", "HIP_C_64F") + s = s.replace("hipDataType_f64_r", "HIP_R_64F") + s = s.replace("hipDataType_f16_c", "HIP_C_16F") + s = s.replace("hipDataType_f16_r", "HIP_R_16F") with open(dst_file_path, "w") as f: f.write(s)