From 8fdf6a7adb3388841ba44a63b6b19e9fdf45ac6f Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Fri, 10 Nov 2023 13:08:24 -0500 Subject: [PATCH] [Migraphx EP] Static int8 QDQ support (#17931) (#23) ### Description Adding static int8 quantization support for MIGraphX Execution Provider - Allows for parsing in calibration tables generated by Onnxruntime or TensorRT's toolsets - Add proper environment variables into the MIGraphX EP - Update python API to include updating execution provider flags -> was missing on python side - Hook into MIGraphX's int8 quantization and optimization of models ### Motivation and Context Required so that we can get onnxruntime to pass in models while leveraging the existing tooling for int8 static QDQ quantization. First step in a series of PRs which will add further static quantization on the operator level as MIGraphX releases further support. These changes drew heavily from the tensorRT EP and should allow for similar functionality for GPU based (versus CPU) quantization of models before an inference is performed. 
--------- Co-authored-by: Ted Themistokleous Co-authored-by: Ted Themistokleous --- .../core/session/onnxruntime_c_api.h | 8 +- .../core/providers/migraphx/migraphx_call.cc | 4 +- .../migraphx/migraphx_execution_provider.cc | 281 ++++++++++++++++-- .../migraphx/migraphx_execution_provider.h | 27 +- .../migraphx_execution_provider_info.cc | 11 +- .../migraphx_execution_provider_info.h | 3 + .../migraphx_execution_provider_utils.h | 115 ++++++- .../migraphx/migraphx_provider_factory.cc | 32 +- .../migraphx/ort_trt_int8_cal_table.fbs.h | 145 +++++++++ .../python/onnxruntime_pybind_state.cc | 100 ++++++- .../python/tools/transformers/benchmark.py | 2 +- onnxruntime/test/util/default_providers.cc | 4 +- 12 files changed, 671 insertions(+), 61 deletions(-) create mode 100644 onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e483c67a0cfe6..a4d5bd12be7f7 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -598,9 +598,11 @@ typedef struct OrtTensorRTProviderOptions { * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX */ typedef struct OrtMIGraphXProviderOptions { - int device_id; // hip device id. - int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true - int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true + int device_id; // hip device id. + int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. 
Default 0 = false, noznero = true + const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index cd947420b7615..5248ac2f39214 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared_library/provider_api.h" #include #include #include #include -#include "migraphx_call.h" #include "core/common/common.h" #include "core/common/status.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index d2538544db60e..d1b3f19100942 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,5 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License +#include +#include +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT @@ -12,10 +17,6 @@ #include "gpu_data_transfer.h" #include "migraphx_inc.h" -#include -#include -#include - // TODO: find a better way to share this #include "core/providers/rocm/rocm_stream_handle.h" @@ -113,6 +114,45 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? 
false : true); } + // whether int8 is enabled + const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); + if (!int8_enable_env.empty()) { + int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + } + + if (int8_enable_) { + const std::string int8_calibration_cache_name_env = + onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); + if (!int8_calibration_cache_name_env.empty()) { + int8_calibration_cache_name_ = int8_calibration_cache_name_env; + } + + const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); + if (!cache_path.empty()) { + calibration_cache_path_ = cache_path; + } + + const std::string int8_use_native_migraphx_calibration_table_env = + onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); + if (!int8_use_native_migraphx_calibration_table_env.empty()) { + int8_use_native_migraphx_calibration_table_ = + (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? 
false : true); + } + } + + if (int8_enable_) { + int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if (int8_enable_ && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + } + } + // dump unsupported ops const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); if (!dump_model_ops_env.empty()) { @@ -124,6 +164,15 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_)); + + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " + << "device_id: " << device_id_ + << ", migraphx_fp16_enable: " << fp16_enable_ + << ", migraphx_int8_enable: " << int8_enable_ + << ", dump_model_ops: " << dump_model_ops_ + << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ + << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_; } MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { @@ -467,7 +516,8 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return false; } -void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, const logging::Logger& logger) { +void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, + const logging::Logger& logger) { // Then check whether a subgraph should fallback to CPU // 1. 
Check whether a subgraph contains a RNN operator std::unordered_set rnn_names = {"RNN", "GRU", "LSTM"}; @@ -642,7 +692,8 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st fused_inputs.erase(iter); erased.insert(output); } else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()) != graph_output_names.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } fused_outputs[output] = output_order++; @@ -660,7 +711,8 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st } // Only when output is neither in input list nor erased list, add the output to output list else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()) != graph_output_names.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } fused_outputs[output] = output_order++; @@ -733,31 +785,156 @@ static std::vector GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& mgx_required_initializers, const logging::Logger& logger) { - static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And", - "ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "ATen", "AveragePool", - "BatchNormalization", "Cast", "Ceil", "Celu", "Clip", "Concat", "Constant", "ConstantFill", - "ConstantOfShape", "Conv", "ConvInteger", "ConvTranspose", "Cos", "Cosh", "CumSum", - "DepthToSpace", "DequantizeLinear", "Div", "Dropout", "Elu", "Equal", "Erf", "Exp", - "Expand", "EyeLike", "Flatten", "Floor", "GRU", "Gather", "GatherElements", "GatherND", "Gemm", "GlobalAveragePool", - "GlobalMaxPool", "Greater", "GreaterOrEqual", "HardSigmoid", "HardSwish", "Identity", - "If", "ImageScaler", 
"InstanceNormalization", "IsNan", "LeakyRelu", "Less", "LessOrEqual", - "Log", "LogSoftmax", "Loop", "LpNormalization", "LRN", "LSTM", "MatMul", "MatMulInteger", "Max", "MaxPool", - "Mean", "Min", "Mod", "Mul", "Multinomial", "Neg", "NonMaxSuppression", "NonZero", "Not", - "OneHot", "Or", "Pad", "Pow", "PRelu", "QuantizeLinear", "RandomNormal", "RandomNormalLike", - "RandomUniform", "RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2", - "ReduceLogSum", "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", - "ReduceSum", "ReduceSumSquare", "Relu", "Reshape", "Resize", "ReverseSequence", "RNN", "Roialign", "Round", - "Scatter", "ScatterElements", "ScatterND", "Selu", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Softplus", - "Softsign", "SpaceToDepth", "Split", "Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh", - "ThresholdedRelu", "Tile", "TopK", "Transpose", "Trilu", "Unsqueeze", "Upsample", "Where", "Xor"}; + static std::set mgx_supported_ops = {"Abs", + "Acos", + "Acosh", + "Add", + "And", + "ArgMax", + "ArgMin", + "Asin", + "Asinh", + "Atan", + "Atanh", + "ATen", + "AveragePool", + "BatchNormalization", + "Cast", + "Ceil", + "Celu", + "Clip", + "Concat", + "Constant", + "ConstantFill", + "ConstantOfShape", + "Conv", + "ConvInteger", + "ConvTranspose", + "Cos", + "Cosh", + "CumSum", + "DepthToSpace", + "DequantizeLinear", + "Div", + "Dropout", + "Elu", + "Equal", + "Erf", + "Exp", + "Expand", + "EyeLike", + "Flatten", + "Floor", + "GRU", + "Gather", + "GatherElements", + "GatherND", + "Gemm", + "GlobalAveragePool", + "GlobalMaxPool", + "Greater", + "GreaterOrEqual", + "HardSigmoid", + "HardSwish", + "Identity", + "If", + "ImageScaler", + "InstanceNormalization", + "IsNan", + "LeakyRelu", + "Less", + "LessOrEqual", + "Log", + "LogSoftmax", + "Loop", + "LpNormalization", + "LRN", + "LSTM", + "MatMul", + "MatMulInteger", + "Max", + "MaxPool", + "Mean", + "Min", + "Mod", + "Mul", + "Multinomial", + "Neg", + 
"NonMaxSuppression", + "NonZero", + "Not", + "OneHot", + "Or", + "Pad", + "Pow", + "PRelu", + "QLinearAdd", + "QLinearConv", + "QLinearMatMul", + "QuantizeLinear", + "RandomNormal", + "RandomNormalLike", + "RandomUniform", + "RandomUniformLike", + "Range", + "Reciprocal", + "ReduceL1", + "ReduceL2", + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceMax", + "ReduceMean", + "ReduceMin", + "ReduceProd", + "ReduceSum", + "ReduceSumSquare", + "Relu", + "Reshape", + "Resize", + "ReverseSequence", + "RNN", + "Roialign", + "Round", + "Scatter", + "ScatterElements", + "ScatterND", + "Selu", + "Shape", + "Sigmoid", + "Sign", + "Sin", + "Sinh", + "Slice", + "Softmax", + "Softplus", + "Softsign", + "SpaceToDepth", + "Split", + "Sqrt", + "Squeeze", + "Sub", + "Sum", + "Tan", + "Tanh", + "ThresholdedRelu", + "Tile", + "TopK", + "Transpose", + "Trilu", + "Unsqueeze", + "Upsample", + "Where", + "Xor"}; std::vector unsupported_nodes_idx; for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { // Collect inputs that are initializers - graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { + graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, + &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) { mgx_required_initializers.insert(node_arg.Name()); - } }, true); + } }, + true); } else { unsupported_nodes_idx.push_back(node_idx); } @@ -770,7 +947,8 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, // is split into 3 parts. supported_cluster + (UNsupported_node + rest_of_the_graph). 
// This functions returns vector of all supported_subgraphx by amdmigraphx static std::vector> -GetPartitionedSubgraphs(const std::vector& topological_order, const std::vector& unsupported_nodes) { +GetPartitionedSubgraphs(const std::vector& topological_order, + const std::vector& unsupported_nodes) { std::vector> mgx_subgraphx; auto prev = topological_order.begin(); @@ -948,6 +1126,24 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::quantize_fp16(prog); } + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if (int8_enable_ && int8_calibration_cache_available_) { + migraphx::quantize_int8_options quant_opts; + migraphx::program_parameters quant_params; + + auto param_shapes = prog.get_parameter_shapes(); + + for (auto&& name : param_shapes.names()) { + auto dynamic_range_i = dynamic_range_map.find(name); + if (dynamic_range_i != dynamic_range_map.end()) { + quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second))); + } + } + + quant_opts.add_calibration_data(quant_params); + // perform static quantization on the programs + migraphx::quantize_int8(prog, t_, quant_opts); + } prog.compile(t_); auto prog_output_shapes = prog.get_output_shapes(); for (std::size_t i = 0; i < output_names.size(); ++i) { @@ -967,7 +1163,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, dump_model_ops_}; + map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_}; *state = p.release(); return 0; }; @@ -982,12 +1179,15 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& 
MIGraphXFuncState* mgx_state = reinterpret_cast(state); std::unordered_map& map_input_name_index = mgx_state->input_name_indexes; + std::unordered_map& map_dynamic_range = mgx_state->dynamic_range_map; migraphx::target t = mgx_state->t; migraphx::program& prog = mgx_state->prog; std::string& onnx_string = mgx_state->onnx_string; migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool int8_enable = mgx_state->int8_enable; + bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; // mean no program at all, so need to get the input shape info // from input data @@ -1043,6 +1243,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::quantize_fp16(prog); } + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if (int8_enable && int8_calibration_cache_available) { + migraphx::quantize_int8_options quant_opts; + migraphx::program_parameters quant_params; + + auto param_shapes = prog.get_parameter_shapes(); + + for (auto&& name : param_shapes.names()) { + auto dynamic_range_i = map_dynamic_range.find(name); + if (dynamic_range_i != map_dynamic_range.end()) { + quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second))); + } + } + + quant_opts.add_calibration_data(quant_params); + // perform static quantization on the programs + migraphx::quantize_int8(prog, t, quant_opts); + } + prog.compile(t); mgx_state->prog = prog; param_shapes = prog.get_parameter_shapes(); @@ -1137,9 +1356,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& return Status::OK(); } -void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const { +void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, + AllocatorMap& allocators) 
const { auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)]; - RegisterRocmStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, false /*TODO:external_stream_*/, external_miopen_handle_, external_rocblas_handle_); + RegisterRocmStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, + false /*TODO:external_stream_*/, external_miopen_handle_, external_rocblas_handle_); } OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 1f591f9a1c0a5..c094be51012e4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,23 +3,29 @@ #pragma once +#include +#include + #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" -#include "migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_inc.h" #include -#include "migraphx_inc.h" +#include // TODO: find a better way to share this // #include "core/providers/cuda/rocm_stream_handle.h" -#include -#include namespace onnxruntime { namespace migraphx_env_vars { -static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"; -static const std::string dumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; +static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; +static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; +static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; +static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; +static const char kINT8UseNativeMIGraphXCalibrationTable[] = 
"ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; }; // namespace migraphx_env_vars // Information to construct kernel function state. @@ -35,6 +41,9 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool int8_enable = false; + bool int8_calibration_cache_available = false; + std::unordered_map dynamic_range_map; bool dump_model_ops = false; }; @@ -69,6 +78,12 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: bool fp16_enable_ = false; + bool int8_enable_ = false; + std::string int8_calibration_cache_name_; + bool int8_calibration_cache_available_ = false; + bool int8_use_native_migraphx_calibration_table_ = false; + std::string calibration_cache_path_; + std::unordered_map dynamic_range_map; bool dump_model_ops_ = false; int device_id_; migraphx::target t_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index bdf8388e75c15..b7d7a77853df6 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,7 +14,10 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; -constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kInt8Enable = "migx_int8_enable"; +constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; +constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; + } // namespace provider_option_names } // namespace migraphx @@ -45,7 +48,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, 
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}}; + {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, + }; return options; } @@ -53,7 +57,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}}; + {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, + }; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 472d418c9099c..18ac30fdc1283 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/framework/ortdevice.h" #include "core/framework/provider_options.h" @@ -16,6 +17,8 @@ struct MIGraphXExecutionProviderInfo { int device_id{0}; bool fp16_enable{false}; bool int8_enable{false}; + std::string int8_calibration_table_name{""}; + bool int8_use_native_calibration_table{false}; static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index fb0be15986111..071070e92a209 100644 --- 
a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -2,8 +2,20 @@ // Licensed under the MIT License #pragma once + +#include +#include +#include +#include +#include +#include +#include "flatbuffers/idl.h" +#include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/execution_provider.h" +#include "core/common/path_string.h" + +namespace fs = std::filesystem; namespace onnxruntime { @@ -101,7 +113,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector return true; } -bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector indices, std::vector& input_nodes) { +bool canEvalNodeArgument(const GraphViewer& graph, + const Node* node, + std::vector indices, + std::vector& input_nodes) { input_nodes.clear(); std::vector in_nodes; for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) { @@ -137,4 +152,102 @@ bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector return true; } +float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { + int s = (input >> 31) & 0x01; + int e = ((input & 0x7f800000) >> 23) - 127; + int p = -1; + double m = 0.0; + for (int i = 0; i < 23; ++i) { + m += ((input >> (23 - i - 1)) & 0x01) * pow(2.0, p--); + } + return static_cast((s ? -1 : 1) * pow(2.0, e) * (m + 1.0)); +} + +/* + * Read calibration table for INT8 quantization + * Two kind of calibration tables are supported, + * 1. ORT generated calibration table + * The table is pre-serialized by flatbuffers. + * Each entry in the table is a key-value pair, + * key: tensor name, value: maximum absolute value in floating point + * For example, + * data_0 2.008338 + * ... + * 2. 
Native TensorRT generated calibration table + * Data format is defined by TensorRT as, + * tensor name : scale in 32-bit single precision IEEE754 format + * For example, + * TRT-7103-EntropyCalibration2 + * data_0: 4000889d + * ... + * + * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models + * + */ +bool ReadDynamicRange(const std::string file_name, + const bool is_calibration_table, + std::unordered_map& dynamic_range_map) { + std::ifstream infile(file_name, std::ios::binary | std::ios::in); + if (!infile) { + return false; + } + + if (is_calibration_table) { + // Native TensorRT generated calibration table + std::string line; + char delim = ':'; + if (std::getline(infile, line)) { + std::istringstream first_line(line); + std::string version; + std::getline(first_line, version, delim); + std::size_t found = version.find("TRT-"); + if (found != std::string::npos) { + while (std::getline(infile, line)) { + std::istringstream in_line(line); + std::string str; + std::getline(in_line, str, delim); + std::string tensor_name = str; + std::getline(in_line, str, delim); + uint32_t scale_int = std::strtoul(str.c_str(), nullptr, 16); + float scale_float = ConvertSinglePrecisionIEEE754ToFloat(scale_int); + float dynamic_range = scale_float * 127.0f; + dynamic_range_map[tensor_name] = dynamic_range; + } + } else { + throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + } + } + } else { + // ORT generated calibration table + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read(reinterpret_cast(data.get()), length); + infile.close(); + auto flat_table = flatbuffers::GetRoot(reinterpret_cast(data.get())); + auto flat_dict = flat_table->dict(); + for (size_t i = 0, end = flat_dict->size(); i < end; ++i) { + flatbuffers::uoffset_t idx = static_cast(i); + 
dynamic_range_map[flat_dict->Get(idx)->key()->str()] = std::stof(flat_dict->Get(idx)->value()->str()); + } + } + return true; +} + +/* + * Get cache by name + * + */ +std::string GetCachePath(const std::string& root, const std::string& name) { + if (root.empty()) { + return name; + } else { + fs::path path = root; + path.append(name); + return path.string(); + } +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 8358ca5fcda95..f985682ddc735 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" @@ -8,7 +9,6 @@ #include "hip_allocator.h" #include "gpu_data_transfer.h" #include "core/framework/provider_options.h" -#include #include "core/session/onnxruntime_c_api.h" @@ -48,15 +48,37 @@ struct MIGraphX_Provider : Provider { info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; info.int8_enable = options.migraphx_int8_enable; + info.int8_calibration_table_name = ""; + if (options.migraphx_int8_calibration_table_name != nullptr) { + info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; + } + info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; return std::make_shared(info); } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& trt_options = *reinterpret_cast(provider_options); - trt_options.device_id = internal_options.device_id; - trt_options.migraphx_fp16_enable = 
internal_options.fp16_enable; - trt_options.migraphx_int8_enable = internal_options.int8_enable; + auto& migx_options = *reinterpret_cast(provider_options); + migx_options.device_id = internal_options.device_id; + migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_int8_enable = internal_options.int8_enable; + + char* dest = nullptr; + auto str_size = internal_options.int8_calibration_table_name.size(); + if (str_size == 0) { + migx_options.migraphx_int8_calibration_table_name = nullptr; + } else { + dest = new char[str_size + 1]; +#ifdef _MSC_VER + strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); +#else + strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); +#endif + dest[str_size] = '\0'; + migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + } + + migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h b/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h new file mode 100644 index 0000000000000..9639040f772da --- /dev/null +++ b/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h @@ -0,0 +1,145 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORT_TRT_INT8_CAL_TABLE_FBS_H_ +#define ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORT_TRT_INT8_CAL_TABLE_FBS_H_ + +#include +#include "flatbuffers/flatbuffers.h" + +namespace CalTableFlatBuffers { + +struct KeyValue; +struct KeyValueBuilder; + +struct TrtTable; +struct TrtTableBuilder; + +struct KeyValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef KeyValueBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEY = 4, + VT_VALUE = 6 + }; + const 
flatbuffers::String* key() const { + return GetPointer(VT_KEY); + } + bool KeyCompareLessThan(const KeyValue* o) const { + return *key() < *o->key(); + } + int KeyCompareWithValue(const char* val) const { + return strcmp(key()->c_str(), val); + } + const flatbuffers::String* value() const { + return GetPointer(VT_VALUE); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffsetRequired(verifier, VT_KEY) && + verifier.VerifyString(key()) && + VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyString(value()) && + verifier.EndTable(); + } +}; + +struct KeyValueBuilder { + typedef KeyValue Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset key) { + fbb_.AddOffset(KeyValue::VT_KEY, key); + } + void add_value(flatbuffers::Offset value) { + fbb_.AddOffset(KeyValue::VT_VALUE, value); + } + explicit KeyValueBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + KeyValueBuilder& operator=(const KeyValueBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, KeyValue::VT_KEY); + return o; + } +}; + +inline flatbuffers::Offset CreateKeyValue( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset key = 0, + flatbuffers::Offset value = 0) { + KeyValueBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_key(key); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateKeyValueDirect( + flatbuffers::FlatBufferBuilder& _fbb, + const char* key = nullptr, + const char* value = nullptr) { + auto key__ = key ? _fbb.CreateString(key) : 0; + auto value__ = value ? 
_fbb.CreateString(value) : 0; + return CalTableFlatBuffers::CreateKeyValue( + _fbb, + key__, + value__); +} + +struct TrtTable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TrtTableBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DICT = 4 + }; + const flatbuffers::Vector>* dict() const { + return GetPointer>*>(VT_DICT); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DICT) && + verifier.VerifyVector(dict()) && + verifier.VerifyVectorOfTables(dict()) && + verifier.EndTable(); + } +}; + +struct TrtTableBuilder { + typedef TrtTable Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_dict(flatbuffers::Offset>> dict) { + fbb_.AddOffset(TrtTable::VT_DICT, dict); + } + explicit TrtTableBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TrtTableBuilder& operator=(const TrtTableBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTrtTable( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset>> dict = 0) { + TrtTableBuilder builder_(_fbb); + builder_.add_dict(dict); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTrtTableDirect( + flatbuffers::FlatBufferBuilder& _fbb, + std::vector>* dict = nullptr) { + auto dict__ = dict ? 
_fbb.CreateVectorOfSortedTables(dict) : 0; + return CalTableFlatBuffers::CreateTrtTable( + _fbb, + dict__); +} + +} // namespace CalTableFlatBuffers + +#endif // ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORT_TRT_INT8_CAL_TABLE_FBS_H_ diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 907ea0ec41e23..8423dcfbadc58 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -718,33 +718,115 @@ std::unique_ptr CreateExecutionProviderInstance( } } } - LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please reference https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met."; + LOGS_DEFAULT(WARNING) << "Failed to create " + << type + << ". Please reference " + << "https://onnxruntime.ai/docs/execution-providers/" + << "TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met."; #endif } else if (type == kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX - return onnxruntime::MIGraphXProviderFactoryCreator::Create(0)->CreateProvider(); + std::string calibration_table; + auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + OrtMIGraphXProviderOptions params{ + 0, + 0, + 0, + 0, + nullptr}; + for (auto option : it->second) { + if (option.first == "device_id") { + if (!option.second.empty()) { + params.device_id = std::stoi(option.second); + } else { + ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); + } + } else if (option.first == "migraphx_fp16_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp16_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp16_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" + " 'True' or 'False'. 
Default value is 'False'.\n"); + } + } else if (option.first == "migraphx_int8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_int8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_int8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == "migraphx_int8_calibration_table_name") { + if (!option.second.empty()) { + calibration_table = option.second; + params.migraphx_int8_calibration_table_name = calibration_table.c_str(); + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a " + "file name i.e. 'cal_table'.\n"); + } + } else if (option.first == "migraphx_use_native_calibration_table") { + if (option.second == "True" || option.second == "true") { + params.migraphx_use_native_calibration_table = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_use_native_calibration_table = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else { + ORT_THROW("Invalid MIGraphX EP option: ", option.first); + } + } + if (std::shared_ptr migraphx_provider_factory = + onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) { + return migraphx_provider_factory->CreateProvider(); + } + } else { + if (std::shared_ptr migraphx_provider_factory = + onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) { + return migraphx_provider_factory->CreateProvider(); + } + } #endif } else if (type == kCudaExecutionProvider) { #ifdef USE_CUDA - // If the environment variable 'CUDA_UNAVAILABLE' exists, then we do not load cuda. 
This is set by _ld_preload for the manylinux case - // as in that case, trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies. + // If the environment variable 'CUDA_UNAVAILABLE' exists, then we do not load cuda. + // This is set by _ld_preload for the manylinux case as in that case, + // trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies. if (Env::Default().GetEnvironmentVar("ORT_CUDA_UNAVAILABLE").empty()) { if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) { const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info, provider_options_map); - // This variable is never initialized because the APIs by which it should be initialized are deprecated, however they still - // exist are are in-use. Neverthless, it is used to return CUDAAllocator, hence we must try to initialize it here if we can - // since FromProviderOptions might contain external CUDA allocator. + // This variable is never initialized because the APIs by which it should be initialized are deprecated, + // however they still exist are are in-use. Neverthless, it is used to return CUDAAllocator, + // hence we must try to initialize it here if we can since FromProviderOptions might contain + // external CUDA allocator. external_allocator_info = info.external_allocator_info; return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); } else { if (!Env::Default().GetEnvironmentVar("CUDA_PATH").empty()) { - ORT_THROW("CUDA_PATH is set but CUDA wasn't able to be loaded. Please install the correct version of CUDA and cuDNN as mentioned in the GPU requirements page (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), make sure they're in the PATH, and that your GPU is supported."); + ORT_THROW( + "CUDA_PATH is set but CUDA wasnt able to be loaded. 
Please install the correct version of CUDA and" + "cuDNN as mentioned in the GPU requirements page " + " (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), " + " make sure they're in the PATH, and that your GPU is supported."); } } } - LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please reference https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements to ensure all dependencies are met."; + LOGS_DEFAULT(WARNING) << "Failed to create " + << type + << ". Please reference " + << "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements" + << "to ensure all dependencies are met."; #endif } else if (type == kRocmExecutionProvider) { #ifdef USE_ROCM diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 97330295e17ed..f506516442b1e 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -779,7 +779,7 @@ def main(): logger.error("fp16 is for GPU only") return - if args.precision == Precision.INT8 and args.use_gpu: + if args.precision == Precision.INT8 and args.use_gpu and args.provider != "migraphx": logger.error("int8 is for CPU only") return diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 28af61e15b2b5..b47368c43cca4 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -69,7 +69,9 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { OrtMIGraphXProviderOptions params{ 0, 0, - 0}; + 0, + 0, + nullptr}; return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); #else return nullptr;