Skip to content

Commit

Permalink
CUDA EP vs ROCM EP hipify audit (#19)
Browse files Browse the repository at this point in the history
- hipify audit of onnxruntime/core/providers/rocm
- hipify audit of onnxruntime/contrib_ops/rocm
- fix contrib ops search implementation
- enable more contrib ops
  - Affine
  - ComplexMul 
  - ConvTransposeWithDynamicPads
  - Crop
  - DynamicSlice
  - FFT [Rfft, Irfft]
  - GreedySearch
  - ImageScaler
  - ParametricSoftplus
  - ScaledTanh
  - ThresholdRelu
  • Loading branch information
jeffdaily authored Oct 3, 2023
1 parent a6cb1a5 commit 217ecdd
Show file tree
Hide file tree
Showing 46 changed files with 1,501 additions and 1,504 deletions.
3 changes: 2 additions & 1 deletion cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,7 @@ if (onnxruntime_USE_ROCM)
find_package(hiprand REQUIRED)
find_package(rocblas REQUIRED)
find_package(MIOpen REQUIRED)
find_package(hipfft REQUIRED)

# MIOpen version
if(NOT DEFINED ENV{MIOPEN_PATH})
Expand Down Expand Up @@ -1548,7 +1549,7 @@ if (onnxruntime_USE_ROCM)

find_library(RCCL_LIB rccl REQUIRED)
find_library(ROCTRACER_LIB roctracer64 REQUIRED)
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB})
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB})

file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h"
Expand Down
27 changes: 0 additions & 27 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ set(contrib_ops_excluded_files
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/nhwc_conv.cc"
"math/complex_mul.cc"
"math/complex_mul.h"
"math/complex_mul_impl.cu"
"math/complex_mul_impl.h"
"math/cufft_plan_cache.h"
"math/fft_ops.cc"
"math/fft_ops.h"
"math/fft_ops_impl.cu"
"math/fft_ops_impl.h"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
Expand Down Expand Up @@ -86,19 +77,6 @@ set(contrib_ops_excluded_files
"quantization/qordered_ops/qordered_unary_ops.cc"
"quantization/qordered_ops/qordered_unary_ops_impl.h"
"quantization/qordered_ops/qordered_unary_ops_impl.cu"
"tensor/crop.cc"
"tensor/crop.h"
"tensor/crop_impl.cu"
"tensor/crop_impl.h"
"tensor/dynamicslice.cc"
"tensor/image_scaler.cc"
"tensor/image_scaler.h"
"tensor/image_scaler_impl.cu"
"tensor/image_scaler_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"conv_transpose_with_dynamic_pads.cc"
"conv_transpose_with_dynamic_pads.h"
"cuda_contrib_kernels.cc"
"cuda_contrib_kernels.h"
"inverse.cc"
Expand All @@ -114,10 +92,6 @@ endif()

set(provider_excluded_files
"atomic/common.cuh"
"controlflow/loop.cc"
"controlflow/loop.h"
"controlflow/scan.cc"
"controlflow/scan.h"
"cu_inc/common.cuh"
"math/einsum_utils/einsum_auxiliary_ops.cc"
"math/einsum_utils/einsum_auxiliary_ops.h"
Expand Down Expand Up @@ -165,7 +139,6 @@ set(provider_excluded_files
"cuda_memory_check.h"
"cuda_fence.cc"
"cuda_fence.h"
"cuda_fwd.h"
"cuda_kernel.h"
"cuda_pch.cc"
"cuda_pch.h"
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -240,7 +240,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
device_copy_int32_func_,
update_gpt_feeds_fp16_func_,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down Expand Up @@ -271,7 +271,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -293,7 +293,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_,
expand_buffer_float16_func_,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -320,7 +320,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -341,7 +341,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_,
expand_buffer_float16_func_,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class BeamSearch : public IControlFlowKernel {
create_beam_scorer_func_ = create_beam_scorer_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
void SetDeviceHelpers_Cuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func) {
Expand Down Expand Up @@ -96,7 +96,7 @@ class BeamSearch : public IControlFlowKernel {
expand_buffer_float16_func_ = expand_buffer_float16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
const void* cuda_device_prop_ = nullptr;
int cuda_device_arch_ = 0;
#endif
Expand All @@ -115,7 +115,7 @@ class BeamSearch : public IControlFlowKernel {
GenerationDeviceHelper::InitBeamStateFunc<MLFloat16> init_beam_state_fp16_func_;
GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class BeamSearchGpt : public BeamSearchBase<T> {
update_feeds_func_(update_feeds_func),
create_beam_scorer_func_(create_beam_scorer_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const void* cuda_device_prop,
Expand Down Expand Up @@ -100,7 +100,7 @@ class BeamSearchGpt : public BeamSearchBase<T> {
GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif
GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
Expand Down Expand Up @@ -336,7 +336,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
// Increase sequence length after a new token is generated.
++current_length;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
// contains DecoderMaskedSelfAttention nodes
if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
expand_buffer_float16_func_(expand_buffer_float16_func),
create_beam_scorer_func_(create_beam_scorer_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
Expand Down Expand Up @@ -87,7 +87,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
// Device specific functions
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
#endif
Expand Down Expand Up @@ -280,7 +280,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BeamSearchWhisper : public BeamSearchBase<T> {
expand_buffer_float16_func_(expand_buffer_float16_func),
create_beam_scorer_func_(create_beam_scorer_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
Expand Down Expand Up @@ -85,7 +85,7 @@ class BeamSearchWhisper : public BeamSearchBase<T> {
// Device specific functions
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
#endif
Expand Down Expand Up @@ -272,7 +272,7 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum DeviceCopyDirection {

namespace GenerationDeviceHelper {

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
using ReorderPastStateFunc = std::function<Status(
const void* cuda_device_prop, // cudaDeviceProp
Tensor& past_state,
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -227,7 +227,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
init_greedy_state_fp16_func_,
device_copy_func_,
update_gpt_feeds_fp16_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/greedy_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class GreedySearch : public IControlFlowKernel {
init_greedy_state_fp16_func_ = init_greedy_state_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) {
reorder_past_state_func_ = reorder_past_state_func;
}
Expand All @@ -73,7 +73,7 @@ class GreedySearch : public IControlFlowKernel {
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
const void* cuda_device_prop_ = nullptr;
int cuda_device_arch_ = 0;
#endif
Expand All @@ -90,7 +90,7 @@ class GreedySearch : public IControlFlowKernel {
GenerationDeviceHelper::InitGreedyStateFunc<float> init_greedy_state_func_;
GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16> init_greedy_state_fp16_func_;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class GreedySearchGpt : public GreedySearchBase<T, ParametersT> {
init_greedy_state_func_(init_greedy_state_func),
update_feeds_func_(update_feeds_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const void* cuda_device_prop,
Expand Down Expand Up @@ -109,7 +109,7 @@ class GreedySearchGpt : public GreedySearchBase<T, ParametersT> {
GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitGreedyStateFunc<T> init_greedy_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif
GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
Expand Down Expand Up @@ -336,7 +336,7 @@ Status GreedySearchGpt<T, ParametersT>::Execute(const FeedsFetchesManager* init_
// Increase sequence length after a new token is generated.
++current_length;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
// contains DecoderMaskedSelfAttention nodes
if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const {
init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -163,7 +163,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const {
init_greedy_state_fp16_func_,
device_copy_func_,
update_gpt_feeds_fp16_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Sampling : public IControlFlowKernel {
init_greedy_state_fp16_func_ = init_greedy_state_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) {
reorder_past_state_func_ = reorder_past_state_func;
}
Expand All @@ -70,7 +70,7 @@ class Sampling : public IControlFlowKernel {
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
const void* gpu_device_prop_ = nullptr;
int gpu_device_arch_ = 0;
#endif
Expand All @@ -87,7 +87,7 @@ class Sampling : public IControlFlowKernel {
GenerationDeviceHelper::InitGreedyStateFunc<float> init_greedy_state_func_;
GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16> init_greedy_state_fp16_func_;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif

Expand Down
Loading

0 comments on commit 217ecdd

Please sign in to comment.