[TensorRT EP] Add new provider option to exclude specific ops from running on TRT #22892

Draft · wants to merge 1 commit into main
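This PR adds a new `trt_op_types_to_exclude` provider option (plus a matching `ORT_TENSORRT_OP_TYPES_TO_EXCLUDE` environment variable) that takes a comma-separated list of ONNX op types to keep off TensorRT. A minimal usage sketch via the ORT C API follows; it assumes a build with the TensorRT EP enabled, the helper name `AppendTrtWithExclusions` is illustrative, and status-code checks are elided:

```cpp
// Sketch: route the listed op types away from the TensorRT EP.
// "trt_op_types_to_exclude" is the provider option key added in this PR;
// its value is a comma-separated op-type list (see GetExcludedNodeSet below).
#include <onnxruntime_c_api.h>

void AppendTrtWithExclusions(const OrtApi* api, OrtSessionOptions* session_options) {
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  api->CreateTensorRTProviderOptions(&trt_options);

  const char* keys[] = {"trt_op_types_to_exclude"};
  const char* values[] = {"NonMaxSuppression,NonZero,RoiAlign"};
  api->UpdateTensorRTProviderOptions(trt_options, keys, values, 1);

  // Register the TRT EP with the configured exclusions, then free the options.
  api->SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options);
  api->ReleaseTensorRTProviderOptions(trt_options);
}
```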
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -85,7 +85,7 @@ struct OrtTensorRTProviderOptionsV2 {
                                                 // can be updated using: UpdateTensorRTProviderOptionsWithValue
  size_t trt_onnx_bytestream_size{0};            // size of the byte stream provided as "trt_onnx_bytestream"
                                                 // can be updated using: UpdateTensorRTProviderOptionsWithValue
  const char* trt_engine_cache_prefix{nullptr};  // specify engine cache prefix
  int trt_engine_hw_compatible{0};               // Enable hardware compatibility. Default 0 = false, nonzero = true
const char* trt_op_types_to_exclude{nullptr}; // Exclude specific ops from running on TRT.
};
40 changes: 28 additions & 12 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1379,6 +1379,8 @@
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
engine_hw_compatible_ = info.engine_hw_compatible;
op_types_to_exclude_ = info.op_types_to_exclude;

} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
@@ -1565,6 +1567,11 @@
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}

const std::string op_types_to_exclude_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOpTypesToExclude);
if (!op_types_to_exclude_env.empty()) {
op_types_to_exclude_ = op_types_to_exclude_env;
}

} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
@@ -1766,7 +1773,8 @@
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
<< ", trt_op_types_to_exclude: " << op_types_to_exclude_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -2434,6 +2442,18 @@
return cycle_detected;
}

std::set<std::string> GetExcludedNodeSet(std::string node_list_to_exclude) {
std::set<std::string> set;
if (!node_list_to_exclude.empty()) {
std::stringstream node_list(node_list_to_exclude);
std::string node;
while (std::getline(node_list, node, ',')) {
set.insert(node);
}
}
return set;
}

std::vector<std::unique_ptr<ComputeCapability>>
TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const IKernelLookup& /*kernel_lookup*/) const {
@@ -2466,17 +2486,13 @@
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);

-  std::set<std::string> exclude_ops_set;
-
-  /*
-   * There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10.
-   * TRT EP automatically excludes DDS ops from running on TRT.
-   */
-  if (trt_version_ >= 100000 && trt_version_ < 110000) {
-    exclude_ops_set.insert("NonMaxSuppression");
-    exclude_ops_set.insert("NonZero");
-    exclude_ops_set.insert("RoiAlign");
-    LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable";
-  }
+  std::set<std::string> exclude_ops_set = GetExcludedNodeSet(op_types_to_exclude_);
+
+  // Print excluded nodes, if any.
+  std::set<std::string>::iterator it;
+  for (it = exclude_ops_set.begin(); it != exclude_ops_set.end(); ++it) {
+    std::string op = *it;
+    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Exclude \"" << op << "\" from running on TRT, if any.";
+  }

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2492: Add #include <set> for set<> [build/include_what_you_use] [4]

SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
@@ -2485,7 +2501,7 @@

/* Iterate all the nodes and exclude the node if:
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
-   * 2. It's a DDS op.
+   * 2. It's in the excluded ops set specified by trt_op_types_to_exclude.
*/
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
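For reference, here is a self-contained sketch of the comma-splitting behavior of the `GetExcludedNodeSet` helper added above (the function body is copied from this diff; the `main()` driver is illustrative only):

```cpp
#include <iostream>
#include <set>
#include <sstream>
#include <string>

// Copied from the diff: splits a comma-separated op-type list into a set;
// an empty input yields an empty set.
std::set<std::string> GetExcludedNodeSet(std::string node_list_to_exclude) {
  std::set<std::string> set;
  if (!node_list_to_exclude.empty()) {
    std::stringstream node_list(node_list_to_exclude);
    std::string node;
    while (std::getline(node_list, node, ',')) {
      set.insert(node);
    }
  }
  return set;
}

int main() {
  // Prints "MaxPool" then "NonZero" (std::set iterates in lexicographic order).
  for (const auto& op : GetExcludedNodeSet("NonZero,MaxPool")) {
    std::cout << op << "\n";
  }
  return 0;
}
```

Note that tokens are not whitespace-trimmed: a value like `"NonZero, MaxPool"` registers `" MaxPool"` (with the leading space), which will not match the node's op type.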
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -57,6 +57,7 @@
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
static const std::string kOpTypesToExclude = "ORT_TENSORRT_OP_TYPES_TO_EXCLUDE";

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h:60: For a static/global string constant, use a C style string instead: "static const char kOpTypesToExclude[]". [runtime/string] [4]
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -56,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
constexpr const char* kONNXBytestream = "trt_onnx_bytestream";
constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size";
constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude";

} // namespace provider_option_names
} // namespace tensorrt
@@ -134,6 +135,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
.AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude)
.Parse(options)); // add new provider option here.

info.user_compute_stream = user_compute_stream;
@@ -188,6 +190,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, MakeStringWithClassicLocale(info.op_types_to_exclude)},
};
return options;
}
@@ -206,6 +209,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes);
const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path);
const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path);
const std::string kOpTypesToExclude_ = empty_if_null(info.trt_op_types_to_exclude);

const ProviderOptions options{
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
@@ -251,6 +255,7 @@
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.trt_onnx_bytestream))},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, kOpTypesToExclude_},
};
return options;
}
@@ -355,5 +360,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream;
trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size;
trt_provider_options_v2.trt_op_types_to_exclude = copy_string_if_needed(internal_options.op_types_to_exclude);
}
} // namespace onnxruntime
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
@@ -60,6 +60,7 @@
int ep_context_embed_mode{0};
std::string engine_cache_prefix{""};
bool engine_hw_compatible{false};
std::string op_types_to_exclude{""};

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h:63: Add #include <string> for string [build/include_what_you_use] [4]

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -118,6 +118,7 @@ struct Tensorrt_Provider : Provider {
info.engine_hw_compatible = options.trt_engine_hw_compatible != 0;
info.onnx_bytestream = options.trt_onnx_bytestream;
info.onnx_bytestream_size = options.trt_onnx_bytestream_size;
info.op_types_to_exclude = options.trt_op_types_to_exclude == nullptr ? "" : options.trt_op_types_to_exclude;

return std::make_shared<TensorrtProviderFactory>(info);
}
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -2410,6 +2410,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
delete[] ptr->trt_profile_opt_shapes;
delete[] ptr->trt_ep_context_file_path;
delete[] ptr->trt_onnx_model_folder_path;
delete[] ptr->trt_op_types_to_exclude;
}

std::unique_ptr<OrtTensorRTProviderOptionsV2> p(ptr);
9 changes: 8 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
@@ -526,7 +526,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
// and TRT EP instance, so it won't be released.)
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
-      onnx_model_folder_path;
+      onnx_model_folder_path, trt_op_types_to_exclude;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params;
@@ -824,6 +824,13 @@
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_op_types_to_exclude") {
if (!option.second.empty()) {
trt_op_types_to_exclude = option.second;
params.trt_op_types_to_exclude = trt_op_types_to_exclude.c_str();
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_op_types_to_exclude' should be a string.\n");
}
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
60 changes: 60 additions & 0 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -612,6 +612,66 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
RunSession(session_object9, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
}

TEST(TensorrtExecutionProviderTest, ExcludeOpsTest) {
/* The mnist.onnx looks like this:
* Conv
* |
* Add
* .
* .
* |
* MaxPool
* |
* .
* .
* MaxPool
* |
* Reshape
* |
* MatMul
* .
* .
*
*/
PathString model_name = ORT_TSTR("testdata/mnist.onnx");
SessionOptions so;
so.session_logid = "TensorrtExecutionProviderExcludeOpsTest";
RunOptions run_options;
run_options.run_tag = so.session_logid;
InferenceSession session_object{so, GetEnvironment()};
auto cuda_provider = DefaultCudaExecutionProvider();
auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1];
std::vector<int64_t> dims_op_x = {1, 1, 28, 28};
std::vector<float> values_op_x(784, 1.0f); // 784=1*1*28*28
OrtValue ml_value_x;
CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_x);
NameMLValMap feeds;
feeds.insert(std::make_pair("Input3", ml_value_x));

// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("Plus214_Output_0");
std::vector<OrtValue> fetches;

RemoveCachesByType("./", ".engine");
OrtTensorRTProviderOptionsV2 params;
params.trt_engine_cache_enable = 1;
params.trt_op_types_to_exclude = "MaxPool";
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
auto status = session_object.Load(model_name);
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());
status = session_object.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(status.IsOK());

std::vector<fs::path> engine_files;
engine_files = GetCachesByType("./", ".engine");
// The whole graph should be partitioned into 3 TRT subgraphs and 2 cpu nodes
ASSERT_EQ(engine_files.size(), 3);
}

TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
PathString model_name = ORT_TSTR("testdata/trt_plugin_custom_op_test.onnx");
SessionOptions so;