Skip to content

Commit 3819961

Browse files
author
Ted Themistokleous
committed
Add ORT_MIGRAPHX_SET_FAST_MATH env option and api hooks
Allow users to set the fast math option for MIGraphX compilation for quantized data types (fp16) This allows us to toggle whether we can use faster math with the tradeoff of accuracy.
1 parent a352c01 commit 3819961

File tree

6 files changed

+28
-3
lines changed

6 files changed

+28
-3
lines changed

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
114114
fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true);
115115
}
116116

117+
// whether fp16 is enable
118+
const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFastMathOptimization);
119+
if (!fast_math_env.empty()) {
120+
fast_math_enable_ = (std::stoi(fast_math_enable_env) == 0 ? false : true);
121+
}
122+
117123
// whether int8 is enabled
118124
const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable);
119125
if (!int8_enable_env.empty()) {
@@ -168,6 +174,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
168174
LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: "
169175
<< "device_id: " << device_id_
170176
<< ", migraphx_fp16_enable: " << fp16_enable_
177+
<< ", migraphx_fast_math: " << fast_math_enable_
171178
<< ", migraphx_int8_enable: " << int8_enable_
172179
<< ", dump_model_ops: " << dump_model_ops_
173180
<< ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
@@ -1145,7 +1152,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
11451152
migraphx::quantize_int8(prog, t_, quant_opts);
11461153
}
11471154
migraphx::compile_options co;
1148-
co.set_fast_math(false);
1155+
co.set_fast_math(fast_math_enable_);
11491156
prog.compile(t_, co);
11501157
auto prog_output_shapes = prog.get_output_shapes();
11511158
for (std::size_t i = 0; i < output_names.size(); ++i) {
@@ -1165,7 +1172,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
11651172
std::unique_ptr<MIGraphXFuncState> p = std::make_unique<MIGraphXFuncState>();
11661173
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
11671174
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
1168-
map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_,
1175+
map_no_input_shape_[context->node_name], fp16_enable_, fast_math_enable_, int8_enable_,
11691176
int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_};
11701177
*state = p.release();
11711178
return 0;
@@ -1265,7 +1272,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
12651272
}
12661273

12671274
migraphx::compile_options co;
1268-
co.set_fast_math(false);
1275+
co.set_fast_math(fast_math_enable);
12691276
prog.compile(t, co);
12701277
mgx_state->prog = prog;
12711278
param_shapes = prog.get_parameter_shapes();

onnxruntime/core/providers/migraphx/migraphx_execution_provider.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
2626
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
2727
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
2828
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
29+
static const char kSetFastMathOptimization[] = "ORT_MIGRAPHX_SET_FAST_MATH";
2930
}; // namespace migraphx_env_vars
3031

3132
// Information to construct kernel function state.
@@ -41,6 +42,7 @@ struct MIGraphXFuncState {
4142
OrtMutex* mgx_mu_ptr = nullptr;
4243
bool no_input_shape = false;
4344
bool fp16_enable = false;
45+
bool fast_math_enable = false;
4446
bool int8_enable = false;
4547
bool int8_calibration_cache_available = false;
4648
std::unordered_map<std::string, float> dynamic_range_map;
@@ -78,6 +80,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
7880

7981
private:
8082
bool fp16_enable_ = false;
83+
bool fast_math_enable_ = false;
8184
bool int8_enable_ = false;
8285
std::string int8_calibration_cache_name_;
8386
bool int8_calibration_cache_available_ = false;

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace migraphx {
1414
namespace provider_option_names {
1515
constexpr const char* kDeviceId = "device_id";
1616
constexpr const char* kFp16Enable = "trt_fp16_enable";
17+
constexpr const char* kFastMathEnable = "migx_fast_math_enable";
1718
constexpr const char* kInt8Enable = "migx_int8_enable";
1819
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
1920
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
@@ -38,6 +39,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
3839
return Status::OK();
3940
})
4041
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
42+
.AddAssignmentToReference(migraphx::provider_option_names::kFastMathEnable, info.fast_math_enable)
4143
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
4244
.Parse(options));
4345

@@ -48,6 +50,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
4850
const ProviderOptions options{
4951
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5052
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
53+
{migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.fast_math_enable)},
5154
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
5255
};
5356
return options;
@@ -57,6 +60,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
5760
const ProviderOptions options{
5861
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
5962
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
63+
{migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.migraphx_fast_math_enable)},
6064
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
6165
};
6266
return options;

onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo {
1616
std::string target_device;
1717
int device_id{0};
1818
bool fp16_enable{false};
19+
bool fast_math_enable{false};
1920
bool int8_enable{false};
2021
std::string int8_calibration_table_name{""};
2122
bool int8_use_native_calibration_table{false};

onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct MIGraphX_Provider : Provider {
4747
info.device_id = options.device_id;
4848
info.target_device = "gpu";
4949
info.fp16_enable = options.migraphx_fp16_enable;
50+
info.fast_math_enable = options.migraphx_fast_math_enable;
5051
info.int8_enable = options.migraphx_int8_enable;
5152
info.int8_calibration_table_name = "";
5253
if (options.migraphx_int8_calibration_table_name != nullptr) {
@@ -61,6 +62,7 @@ struct MIGraphX_Provider : Provider {
6162
auto& migx_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
6263
migx_options.device_id = internal_options.device_id;
6364
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
65+
migx_options.migraphx_fast_math_enable = internal_options.fast_math_enable;
6466
migx_options.migraphx_int8_enable = internal_options.int8_enable;
6567

6668
char* dest = nullptr;

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
734734
0,
735735
0,
736736
0,
737+
0,
737738
nullptr};
738739
for (auto option : it->second) {
739740
if (option.first == "device_id") {
@@ -752,6 +753,13 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
752753
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
753754
" 'True' or 'False'. Default value is 'False'.\n");
754755
}
756+
}
757+
else if (option.first == "migraphx_set_fast_math") {
758+
if (option.second == "True" || option.second == "true") {
759+
params.migraphx_fast_math_enable = true;
760+
} else {
761+
params.migraphx_fast_math_enable = false;
762+
}
755763
} else if (option.first == "migraphx_int8_enable") {
756764
if (option.second == "True" || option.second == "true") {
757765
params.migraphx_int8_enable = true;

0 commit comments

Comments
 (0)