-
Notifications
You must be signed in to change notification settings - Fork 3.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable QNN EP weight sharing generation using public API #23702
base: main
Are you sure you want to change the base?
Changes from 12 commits
69b662d
9ec8b1f
906cfd5
eb6104b
eaf7041
3800cbc
1a135ea
c8cfc2e
c8eb395
0b3dbed
ff59f37
de47a5c
dea89d7
aec117c
92d8423
57783ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,24 +3,19 @@ | |
|
||
// onnxruntime dependencies | ||
#include "test_configuration.h" | ||
#include <core/session/onnxruntime_c_api.h> | ||
#include <core/session/onnxruntime_cxx_api.h> | ||
#include <random> | ||
#include "command_args_parser.h" | ||
#include <google/protobuf/stubs/common.h> | ||
|
||
#include "core/session/onnxruntime_session_options_config_keys.h" | ||
#include "core/session/inference_session.h" | ||
#include "core/session/ort_env.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, I believe this is still an internal header. And I also see that this tool still depends on the internal graph classes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. Actually for that post processing part, user should be able to use Onnx API to update the Onnx model. That's not the main part we want to cover in this tool. But anyway, let me make the changes to use Onnx API to make it clear. |
||
#include "core/providers/provider_factory_creators.h" | ||
#include "core/common/logging/sinks/clog_sink.h" | ||
|
||
#include "core/graph/model.h" | ||
#include "core/session/environment.h" | ||
#include "core/common/logging/logging.h" | ||
|
||
using namespace onnxruntime; | ||
const OrtApi* g_ort = NULL; | ||
using ProviderOptions = std::unordered_map<std::string, std::string>; | ||
|
||
std::unique_ptr<Ort::Env> ort_env; | ||
|
||
static void CheckStatus(const Status& status) { | ||
|
@@ -110,52 +105,29 @@ int real_main(int argc, wchar_t* argv[]) { | |
#else | ||
int real_main(int argc, char* argv[]) { | ||
#endif | ||
g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); | ||
qnnctxgen::TestConfig test_config; | ||
if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { | ||
qnnctxgen::CommandLineParser::ShowUsage(); | ||
return -1; | ||
} | ||
|
||
{ | ||
bool failed = false; | ||
ORT_TRY { | ||
OrtLoggingLevel logging_level = test_config.run_config.f_verbose | ||
? ORT_LOGGING_LEVEL_VERBOSE | ||
: ORT_LOGGING_LEVEL_WARNING; | ||
|
||
ort_env = std::make_unique<Ort::Env>(logging_level, "Default"); | ||
} | ||
ORT_CATCH(const Ort::Exception& e) { | ||
ORT_HANDLE_EXCEPTION([&]() { | ||
fprintf(stderr, "Error creating environment. Error-> %s \n", e.what()); | ||
failed = true; | ||
}); | ||
} | ||
|
||
if (failed) | ||
return -1; | ||
} | ||
OrtLoggingLevel logging_level = test_config.run_config.f_verbose | ||
? ORT_LOGGING_LEVEL_VERBOSE | ||
: ORT_LOGGING_LEVEL_WARNING; | ||
Ort::Env env(logging_level, "ep_weight_sharing"); | ||
|
||
ORT_TRY { | ||
SessionOptions so; | ||
so.session_logid = "qnn_ctx_gen_session_logger"; | ||
// Set default session option to dump QNN context model with non-embed mode | ||
CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); | ||
CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); | ||
RunOptions run_options; | ||
run_options.run_tag = so.session_logid; | ||
Ort::SessionOptions so; | ||
so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); | ||
// Set default session option to dump EPContext model with non-embed mode | ||
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); | ||
so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); | ||
// enable ep.share_ep_contexts | ||
so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); | ||
|
||
ProviderOptions provider_options; | ||
#if defined(_WIN32) | ||
provider_options["backend_path"] = "QnnHtp.dll"; | ||
#else | ||
provider_options["backend_path"] = "libQnnHtp.so"; | ||
#endif | ||
// set default QNN EP option to enable weight sharing | ||
provider_options["enable_htp_weight_sharing"] = "1"; | ||
|
||
for (auto it : test_config.run_config.qnn_options) { | ||
for (auto it : test_config.run_config.provider_options) { | ||
provider_options[it.first] = it.second; | ||
} | ||
|
||
|
@@ -172,7 +144,7 @@ int real_main(int argc, char* argv[]) { | |
std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; | ||
continue; | ||
} | ||
CheckStatus(so.config_options.AddConfigEntry(it.first.c_str(), it.second.c_str())); | ||
so.AddConfigEntry(it.first.c_str(), it.second.c_str()); | ||
} | ||
|
||
for (auto model_path : test_config.model_file_paths) { | ||
|
@@ -182,15 +154,30 @@ int real_main(int argc, char* argv[]) { | |
// Generate context cache model files with QNN context binary files | ||
// The context binary file generated later includes all graphs from previous models | ||
{ | ||
auto ep = QNNProviderFactoryCreator::Create(provider_options, &so)->CreateProvider(); | ||
std::shared_ptr<IExecutionProvider> qnn_ep(std::move(ep)); | ||
std::string provider_name_ = test_config.machine_config.provider_type_name; | ||
if (provider_name_ == onnxruntime::kQnnExecutionProvider) { | ||
#ifdef USE_QNN | ||
#if defined(_WIN32) | ||
provider_options["backend_path"] = "QnnHtp.dll"; | ||
#else | ||
provider_options["backend_path"] = "libQnnHtp.so"; | ||
#endif | ||
// set default QNN EP option to enable weight sharing if not set by user | ||
const std::string enable_htp_weight_sharing = "enable_htp_weight_sharing"; | ||
if (provider_options.find(enable_htp_weight_sharing) == provider_options.end()) { | ||
provider_options[enable_htp_weight_sharing] = "1"; | ||
} | ||
so.AppendExecutionProvider("QNN", provider_options); | ||
#else | ||
ORT_THROW("QNN is not supported in this build\n"); | ||
#endif | ||
} else if (!provider_name_.empty()) { | ||
ORT_THROW("This execution provider is not included in this tool.\n"); | ||
} | ||
|
||
for (auto model_path : test_config.model_file_paths) { | ||
std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl; | ||
InferenceSession session_object{so, ((OrtEnv*)*ort_env.get())->GetEnvironment()}; | ||
CheckStatus(session_object.RegisterExecutionProvider(qnn_ep)); | ||
CheckStatus(session_object.Load(ToPathString(model_path))); | ||
CheckStatus(session_object.Initialize()); | ||
Ort::Session session(env, model_path.c_str(), so); | ||
} | ||
} | ||
|
||
|
@@ -213,7 +200,7 @@ int real_main(int argc, char* argv[]) { | |
|
||
// Update generated context cache Onnx model to make the main EPContext node point to | ||
// the last QNN context binary file | ||
// Remove not used QNN context binary file, only keep the last one which contains all graphs | ||
Remove unused QNN context binary files; keep only the last one, which contains all graphs | ||
UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); | ||
} | ||
ORT_CATCH(const Ort::Exception& e) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can some of these internal libraries be removed now that the tool uses public APIs?