Skip to content

Commit

Permalink
[ROCm] enable hipGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdaily committed Nov 9, 2023
1 parent aed43f4 commit 24f7d47
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 30 deletions.
4 changes: 2 additions & 2 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip
cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159
dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445
eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;ee201b07085203ea7bd8eb97cbcb31b07cfa3efb
eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a
flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf
fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
Expand Down Expand Up @@ -44,4 +44,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ typedef struct OrtROCMProviderOptions {
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
enable_hip_graph{false},
tunable_op_enable{false},
tunable_op_tuning_enable{false},
tunable_op_max_tuning_duration_ms{} {}
Expand Down Expand Up @@ -546,6 +547,8 @@ typedef struct OrtROCMProviderOptions {
*/
OrtArenaCfg* default_memory_arena_cfg;

int enable_hip_graph;

/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE.
Expand Down
77 changes: 72 additions & 5 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de

MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_));
MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream));

hip_graph_.SetStream(stream);
}

ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
Expand All @@ -165,6 +167,33 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
}
}

bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_;
}

void ROCMExecutionProvider::PerThreadContext::CaptureBegin() {
hip_graph_.Reset();
hip_graph_.CaptureBegin();
}

void ROCMExecutionProvider::PerThreadContext::CaptureEnd() {
hip_graph_.CaptureEnd();
is_graph_captured_ = true;
}

bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const {
return is_graph_captured_;
}

Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
return hip_graph_.Replay();
}

void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
++regular_run_count_before_graph_capture_;
}

void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) {
if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead.");
Expand Down Expand Up @@ -207,6 +236,11 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in
if (info.external_allocator_info.UseExternalAllocator()) {
use_ep_level_unified_stream_ = true;
stream_ = nullptr;
} else if (info.enable_hip_graph) {
// current hip graph implementation only works with single stream
// use EP level unified stream for all the reqeust
HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
use_ep_level_unified_stream_ = true;
} else {
stream_ = nullptr;
}
Expand Down Expand Up @@ -310,25 +344,58 @@ Status ROCMExecutionProvider::Sync() const {
Status ROCMExecutionProvider::OnRunStart() {
// always set ROCM device when session::Run() in case it runs in a worker thread
HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model";
GetPerThreadContext().CaptureBegin();
}
return Status::OK();
}

Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
// HIP work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
} else {
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
}
}

if (sync_stream) {
HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_));
}

// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
// PerThreadContext won't be created and there is nothing to
// release. This didn't happen before because we always call
// GetPerThreadContext in OnRunStart.
if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
// The reason of !IsGraphCaptureEnabled():
// If hip graph is enabled, the per thread context will not be released
// because the per thread hip graph needs to be maintained and replayed for
// the next run.
// The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
// PerThreadContext won't be created and there is nothing to
// release. This didn't happen before because we always call
// GetPerThreadContext in OnRunStart.
if (!IsGraphCaptureEnabled() &&
PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
ReleasePerThreadContext();
}

return Status::OK();
}

bool ROCMExecutionProvider::IsGraphCaptureEnabled() const {
return info_.enable_hip_graph;
}

bool ROCMExecutionProvider::IsGraphCaptured() const {
return GetPerThreadContext().IsGraphCaptured();
}

Status ROCMExecutionProvider::ReplayGraph() {
return GetPerThreadContext().ReplayGraph();
}

namespace rocm {
// opset 1 to 9
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
Expand Down
35 changes: 31 additions & 4 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/execution_provider.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
#include "core/providers/rocm/rocm_graph.h"
#include "core/providers/rocm/rocm_pch.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
Expand Down Expand Up @@ -74,14 +75,18 @@ class ROCMExecutionProvider : public IExecutionProvider {

std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;

private:
ROCMExecutionProviderInfo info_;
hipDeviceProp_t device_prop_;
bool external_stream_ = false;
// only used when set user external stream or hip graph
hipStream_t stream_ = nullptr;

bool use_ep_level_unified_stream_ = false;
Expand All @@ -105,17 +110,20 @@ class ROCMExecutionProvider : public IExecutionProvider {

template <typename T>
const T* GetConstOnes(size_t count, hipStream_t stream) {
if (std::is_same<T, float>::value) {
constexpr bool is_float = std::is_same<T, float>::value;
constexpr bool is_double = std::is_same<T, double>::value;
constexpr bool is_half = std::is_same<T, half>::value;
if (is_float) {
if (!constant_ones_float_) {
constant_ones_float_ = rocm::CreateConstantOnes<float>();
}
return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(stream, count));
} else if (std::is_same<T, double>::value) {
} else if (is_double) {
if (!constant_ones_double_) {
constant_ones_double_ = rocm::CreateConstantOnes<double>();
}
return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(stream, count));
} else if (std::is_same<T, half>::value) {
} else if (is_half) {
if (!constant_ones_half_) {
constant_ones_half_ = rocm::CreateConstantOnes<half>();
}
Expand All @@ -125,13 +133,32 @@ class ROCMExecutionProvider : public IExecutionProvider {
}
}

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
bool IsGraphCaptured() const;
Status ReplayGraph();
void IncrementRegularRunCountBeforeGraphCapture();

private:
rocblas_handle rocblas_handle_ = nullptr;
miopenHandle_t miopen_handle_ = nullptr;

std::unique_ptr<rocm::IConstantBuffer<float>> constant_ones_float_;
std::unique_ptr<rocm::IConstantBuffer<double>> constant_ones_double_;
std::unique_ptr<rocm::IConstantBuffer<half>> constant_ones_half_;

// Hip graph with multi threads will be supported in the future, so hip_graph_
// is put under PerThreadContext.
ROCMGraph hip_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;

// There is chance that the second regular run allocates GPU memory for causes like:
// (1) memory pattern is enabled. (2) arena allocation for stream.
// Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
// to allocate enough memory in Arena before graph capturing.
const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations.
};

using PerThreadContextMap = std::unordered_map<const ROCMExecutionProvider*, std::weak_ptr<PerThreadContext>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc";
constexpr const char* kGpuExternalFree = "gpu_external_free";
constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache";
constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace";
constexpr const char* kEnableHipGraph = "enable_hip_graph";
constexpr const char* kTunableOpEnable = "tunable_op_enable";
constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable";
constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms";
Expand Down Expand Up @@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
.AddAssignmentToReference(rocm::provider_option_names::kMiopenConvExhaustiveSearch, info.miopen_conv_exhaustive_search)
.AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
.AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace)
.AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph)
.AddValueParser(
rocm::provider_option_names::kTunableOpEnable,
[&info](const std::string& value_str) -> Status {
Expand Down Expand Up @@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
{rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)},
{rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)},
{rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)},
{rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)},
{rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo {
// If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best.
bool miopen_conv_use_max_workspace{true};

bool enable_hip_graph{false};

rocm::TunableOpInfo tunable_op{};

static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ struct ROCM_Provider : Provider {
info.has_user_compute_stream = params->has_user_compute_stream;
info.user_compute_stream = params->user_compute_stream;
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
info.enable_hip_graph = params->enable_hip_graph;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms;
Expand All @@ -192,6 +193,7 @@ struct ROCM_Provider : Provider {
rocm_options.has_user_compute_stream = info.has_user_compute_stream;
rocm_options.user_compute_stream = info.user_compute_stream;
rocm_options.default_memory_arena_cfg = info.default_memory_arena_cfg;
rocm_options.enable_hip_graph = info.enable_hip_graph;
rocm_options.tunable_op_enable = info.tunable_op.enable;
rocm_options.tunable_op_tuning_enable = info.tunable_op.tuning_enable;
rocm_options.tunable_op_max_tuning_duration_ms = info.tunable_op.max_tuning_duration_ms;
Expand Down
Loading

0 comments on commit 24f7d47

Please sign in to comment.