Enable QNN EP weight sharing generation using public API #23702

Open. Wants to merge 16 commits into base: main.
Changes from 12 commits.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_python.cmake
@@ -1029,7 +1029,7 @@ if (onnxruntime_USE_QNN)
add_custom_command(
TARGET onnxruntime_pybind11_state POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:onnxruntime_qnn_ctx_gen>
$<TARGET_FILE:ep_weight_sharing_ctx_gen>
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
)
if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf")
24 changes: 12 additions & 12 deletions cmake/onnxruntime_unittests.cmake
@@ -1282,31 +1282,31 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)

if(onnxruntime_USE_QNN)
#qnn ctx generator
set(onnxruntime_qnn_ctx_gen_src_dir ${TEST_SRC_DIR}/qnn_ctx_gen)
set(onnxruntime_qnn_ctx_gen_src_patterns
"${onnxruntime_qnn_ctx_gen_src_dir}/*.cc"
"${onnxruntime_qnn_ctx_gen_src_dir}/*.h")
set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen)
set(ep_weight_sharing_ctx_gen_src_patterns
"${ep_weight_sharing_ctx_gen_src_dir}/*.cc"
"${ep_weight_sharing_ctx_gen_src_dir}/*.h")

file(GLOB onnxruntime_qnn_ctx_gen_src CONFIGURE_DEPENDS
${onnxruntime_qnn_ctx_gen_src_patterns}
file(GLOB ep_weight_sharing_ctx_gen_src CONFIGURE_DEPENDS
${ep_weight_sharing_ctx_gen_src_patterns}
)
onnxruntime_add_executable(onnxruntime_qnn_ctx_gen ${onnxruntime_qnn_ctx_gen_src})
target_include_directories(onnxruntime_qnn_ctx_gen PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT}
onnxruntime_add_executable(ep_weight_sharing_ctx_gen ${ep_weight_sharing_ctx_gen_src})
target_include_directories(ep_weight_sharing_ctx_gen PRIVATE ${ONNXRUNTIME_ROOT}
${onnxruntime_graph_header} ${onnxruntime_exec_src_dir}
${CMAKE_CURRENT_BINARY_DIR})
if (WIN32)
target_compile_options(onnxruntime_qnn_ctx_gen PRIVATE ${disabled_warnings})
target_compile_options(ep_weight_sharing_ctx_gen PRIVATE ${disabled_warnings})
if (NOT DEFINED SYS_PATH_LIB)
set(SYS_PATH_LIB shlwapi)
endif()
endif()

if(WIN32)
target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32)
target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE debug dbghelp advapi32)
endif()
target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
Contributor:
Can some of these internal libraries be removed now that the tool uses public APIs?


set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest")
set_target_properties(ep_weight_sharing_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest")
endif()

# shared lib
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -470,8 +470,10 @@ Status QnnBackendManager::InitializeProfiling() {
QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC;
if (ProfilingLevel::BASIC == profiling_level_merge_) {
qnn_profile_level = QNN_PROFILE_LEVEL_BASIC;
LOGS_DEFAULT(VERBOSE) << "Profiling level set to basic.";
} else if (ProfilingLevel::DETAILED == profiling_level_merge_) {
qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED;
LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed.";
}
Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, &profile_backend_handle_);
ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! Error: ", QnnErrorHandleToString(result));
37 changes: 26 additions & 11 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -384,17 +384,26 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
}
}

qnn_backend_manager_ = qnn::QnnBackendManager::Create(
qnn::QnnBackendManagerConfig{backend_path,
profiling_level_etw,
profiling_level,
profiling_file_path,
context_priority,
qnn_saver_path,
device_id_,
htp_arch,
soc_model,
enable_htp_weight_sharing});
// For context binary generation with weight sharing enabled, use the QnnBackendManager from the shared context if it exists,
// so that all graphs from later sessions are compiled into the same QNN context
if (context_cache_enabled_ && share_ep_contexts_ && SharedContext::GetInstance().GetSharedQnnBackendManager()) {
qnn_backend_manager_ = SharedContext::GetInstance().GetSharedQnnBackendManager();
qnn_backend_manager_shared = true;
// Only allow sharing across 2 sessions, so the shared manager does not propagate without limit
SharedContext::GetInstance().ResetSharedQnnBackendManager();
} else {
qnn_backend_manager_ = qnn::QnnBackendManager::Create(
qnn::QnnBackendManagerConfig{backend_path,
profiling_level_etw,
profiling_level,
profiling_file_path,
context_priority,
qnn_saver_path,
device_id_,
htp_arch,
soc_model,
enable_htp_weight_sharing});
}

#if defined(_WIN32)
if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) {
@@ -1028,6 +1037,12 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
qnn_context_embed_mode_,
max_spill_fill_buffer_size,
logger));

if (share_ep_contexts_ && !qnn_backend_manager_shared &&
nullptr == SharedContext::GetInstance().GetSharedQnnBackendManager()) {
ORT_RETURN_IF_NOT(SharedContext::GetInstance().SetSharedQnnBackendManager(qnn_backend_manager_),
"Failed to set shared QnnBackendManager.");
}
}
return Status::OK();
}
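The change above routes later sessions onto the shared QnnBackendManager when ep.share_ep_contexts is enabled, so all of their graphs are compiled into one QNN context. As a rough illustration of the application-side flow this enables, here is a minimal sketch using only the public C++ API, mirroring the updated tool in this PR; the model file names are illustrative assumptions, not part of the diff:

```cpp
// Minimal sketch (not part of this PR's sources): compile two models into one
// shared QNN context binary through the public C++ API.
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ep_weight_sharing_example");

  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");     // dump EPContext models
  so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0");  // keep the context binary external
  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");     // share the QNN context across sessions

  std::unordered_map<std::string, std::string> qnn_options;
#if defined(_WIN32)
  qnn_options["backend_path"] = "QnnHtp.dll";
#else
  qnn_options["backend_path"] = "libQnnHtp.so";
#endif
  qnn_options["enable_htp_weight_sharing"] = "1";
  so.AppendExecutionProvider("QNN", qnn_options);

  // The first session creates the shared QnnBackendManager; the second session
  // reuses it, so both graphs end up in the same QNN context binary.
  Ort::Session session1(env, ORT_TSTR("model1.onnx"), so);  // assumed input path
  Ort::Session session2(env, ORT_TSTR("model2.onnx"), so);  // assumed input path
  return 0;
}
```

With embed mode 0, each session is expected to write a `<model>_ctx.onnx` next to its input model, with the shared context binary stored alongside them.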
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -77,6 +77,8 @@ class QNNExecutionProvider : public IExecutionProvider {
std::shared_ptr<qnn::QnnBackendManager> qnn_backend_manager_;
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models_;
bool context_cache_enabled_ = false;
// Flag indicating whether the shared QnnBackendManager has been consumed. Stop passing it through once it has been.
bool qnn_backend_manager_shared = false;
std::string context_cache_path_cfg_ = "";
std::string context_node_name_prefix_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/qnn/shared_context.h
@@ -61,13 +61,39 @@ class SharedContext {
return graph_exist;
}

bool SetSharedQnnBackendManager(std::shared_ptr<qnn::QnnBackendManager>& qnn_backend_manager) {
const std::lock_guard<std::mutex> lock(mtx_);

if (qnn_backend_manager_ != nullptr) {
if (qnn_backend_manager_ == qnn_backend_manager) {
return true;
}
return false;
}
qnn_backend_manager_ = qnn_backend_manager;
return true;
}

std::shared_ptr<qnn::QnnBackendManager> GetSharedQnnBackendManager() {
const std::lock_guard<std::mutex> lock(mtx_);
return qnn_backend_manager_;
}

void ResetSharedQnnBackendManager() {
const std::lock_guard<std::mutex> lock(mtx_);
qnn_backend_manager_.reset();
}

private:
SharedContext() = default;
~SharedContext() = default;

ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SharedContext);

// Used for passing through QNN models (deserialized from context binary) across sessions
std::vector<std::unique_ptr<qnn::QnnModel>> shared_qnn_models_;
// Used for compiling multiple models into same QNN context binary
std::shared_ptr<qnn::QnnBackendManager> qnn_backend_manager_;
// Producer sessions can run in parallel.
// Consumer sessions must wait until the producer sessions have finished initializing.
std::mutex mtx_;
@@ -2,9 +2,9 @@

This tool provides a way to generate Onnx models that wrap a QNN context binary with weight sharing enabled. The options to use with the tool are listed below:

`onnxruntime_qnn_ctx_gen [options...] model_path,model_path`
`ep_weight_sharing_ctx_gen [options...] model_1_path,model_2_path`

./onnxruntime_qnn_ctx_gen -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" -C "ep.context_enable|1 ep.context_embed_mode|0" /mnt/c/model1.onnx,/mnt/c/model2.onnx
./ep_weight_sharing_ctx_gen -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" -C "ep.context_enable|1 ep.context_embed_mode|0" /mnt/c/model1.onnx,/mnt/c/model2.onnx
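The inference side is not changed by this PR, but for context, here is a hedged sketch of consuming the generated EPContext models. The output names model1_ctx.onnx and model2_ctx.onnx and the use of ep.share_ep_contexts on the inference sessions are assumptions based on the weight-sharing workflow, not on this diff:

```cpp
// Consumer-side sketch (assumption: the tool produced model1_ctx.onnx and
// model2_ctx.onnx, both pointing at one shared QNN context binary).
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "weight_sharing_inference");

  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");  // reuse shared QNN resources across sessions

  std::unordered_map<std::string, std::string> qnn_options;
#if defined(_WIN32)
  qnn_options["backend_path"] = "QnnHtp.dll";
#else
  qnn_options["backend_path"] = "libQnnHtp.so";
#endif
  so.AppendExecutionProvider("QNN", qnn_options);

  // Both EPContext models reference the same context binary, so the shared
  // weights are loaded once and reused by both sessions.
  Ort::Session session1(env, ORT_TSTR("model1_ctx.onnx"), so);
  Ort::Session session2(env, ORT_TSTR("model2_ctx.onnx"), so);
  // ... build Ort::Value inputs and call session1.Run(...) / session2.Run(...) as usual ...
  return 0;
}
```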

Options:

@@ -29,9 +29,11 @@ namespace qnnctxgen {

/*static*/ void CommandLineParser::ShowUsage() {
printf(
"onnxruntime_qnn_ctx_gen [options...] model1_path,model2_path\n"
"Example: ./onnxruntime_qnn_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n"
"ep_weight_sharing_ctx_gen [options...] model1_path,model2_path\n"
"Example: ./ep_weight_sharing_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n"
"Options:\n"
"\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn','tensorrt','openvino', 'vitisai'. "
"Default:'qnn'.\n"
"\t-v: Show verbose information.\n"
"\t-C: Specify session configuration entries as key-value pairs: -C \"<key1>|<value1> <key2>|<value2>\" \n"
"\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n"
@@ -49,7 +51,7 @@ namespace qnnctxgen {
"\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n"
"\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n"
"\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n"
"\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n"
"\t Defaults to '1' (QNN EP handles the graph I/O quantization and dequantization). \n"
"\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary."
"\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n"
"\n"
@@ -109,8 +111,22 @@ static bool ParseSessionConfigs(const std::string& configs_string,

/*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
int ch;
while ((ch = getopt(argc, argv, ORT_TSTR("o:u:i:C:vh"))) != -1) {
while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) {
switch (ch) {
case 'e':
if (!CompareCString(optarg, ORT_TSTR("qnn"))) {
test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("openvino"))) {
test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) {
test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider;
} else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) {
test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider;
} else {
fprintf(stderr, "The execution provider is not included in this tool.\n");
return false;
}
break;
case 'v':
test_config.run_config.f_verbose = true;
break;
@@ -162,7 +178,7 @@ static bool ParseSessionConfigs(const std::string& configs_string,
'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])");
}

test_config.run_config.qnn_options[key] = value;
test_config.run_config.provider_options[key] = value;
}
break;
}
@@ -3,24 +3,19 @@

// onnxruntime dependencies
#include "test_configuration.h"
#include <core/session/onnxruntime_c_api.h>
#include <core/session/onnxruntime_cxx_api.h>
#include <random>
#include "command_args_parser.h"
#include <google/protobuf/stubs/common.h>

#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/inference_session.h"
#include "core/session/ort_env.h"
Contributor:
Hi, I believe this is still an internal header. And I also see that this tool still depends on the internal graph classes onnxruntime::Graph and onnxruntime::Node, which have a public header but a private/internal implementation. Since this still requires the tool to be compiled with internal ORT code, would this prevent users from integrating this into their own toolchains?

Contributor Author:
Good point. Actually, for that post-processing part, the user should be able to use the Onnx API to update the Onnx model. That's not the main part we want to cover in this tool. But anyway, let me make the changes to use the Onnx API to make it clear.

#include "core/providers/provider_factory_creators.h"
#include "core/common/logging/sinks/clog_sink.h"

#include "core/graph/model.h"
#include "core/session/environment.h"
#include "core/common/logging/logging.h"

using namespace onnxruntime;
const OrtApi* g_ort = NULL;
using ProviderOptions = std::unordered_map<std::string, std::string>;

std::unique_ptr<Ort::Env> ort_env;

static void CheckStatus(const Status& status) {
@@ -110,52 +105,29 @@ int real_main(int argc, wchar_t* argv[]) {
#else
int real_main(int argc, char* argv[]) {
#endif
g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
qnnctxgen::TestConfig test_config;
if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) {
qnnctxgen::CommandLineParser::ShowUsage();
return -1;
}

{
bool failed = false;
ORT_TRY {
OrtLoggingLevel logging_level = test_config.run_config.f_verbose
? ORT_LOGGING_LEVEL_VERBOSE
: ORT_LOGGING_LEVEL_WARNING;

ort_env = std::make_unique<Ort::Env>(logging_level, "Default");
}
ORT_CATCH(const Ort::Exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
fprintf(stderr, "Error creating environment. Error-> %s \n", e.what());
failed = true;
});
}

if (failed)
return -1;
}
OrtLoggingLevel logging_level = test_config.run_config.f_verbose
? ORT_LOGGING_LEVEL_VERBOSE
: ORT_LOGGING_LEVEL_WARNING;
Ort::Env env(logging_level, "ep_weight_sharing");

ORT_TRY {
SessionOptions so;
so.session_logid = "qnn_ctx_gen_session_logger";
// Set default session option to dump QNN context model with non-embed mode
CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"));
CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"));
RunOptions run_options;
run_options.run_tag = so.session_logid;
Ort::SessionOptions so;
so.SetLogId("ep_weight_sharing_ctx_gen_session_logger");
// Set default session option to dump EPContext model with non-embed mode
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0");
// enable ep.share_ep_contexts
so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");

ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
// set default QNN EP option to enable weight sharing
provider_options["enable_htp_weight_sharing"] = "1";

for (auto it : test_config.run_config.qnn_options) {
for (auto it : test_config.run_config.provider_options) {
provider_options[it.first] = it.second;
}

@@ -172,7 +144,7 @@
std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl;
continue;
}
CheckStatus(so.config_options.AddConfigEntry(it.first.c_str(), it.second.c_str()));
so.AddConfigEntry(it.first.c_str(), it.second.c_str());
}

for (auto model_path : test_config.model_file_paths) {
@@ -182,15 +154,30 @@ int real_main(int argc, char* argv[]) {
// Generate context cache model files with QNN context binary files
// The context binary file generated later includes all graphs from previous models
{
auto ep = QNNProviderFactoryCreator::Create(provider_options, &so)->CreateProvider();
std::shared_ptr<IExecutionProvider> qnn_ep(std::move(ep));
std::string provider_name_ = test_config.machine_config.provider_type_name;
if (provider_name_ == onnxruntime::kQnnExecutionProvider) {
#ifdef USE_QNN
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
// set default QNN EP option to enable weight sharing if not set by user
const std::string enable_htp_weight_sharing = "enable_htp_weight_sharing";
if (provider_options.find(enable_htp_weight_sharing) == provider_options.end()) {
provider_options[enable_htp_weight_sharing] = "1";
}
so.AppendExecutionProvider("QNN", provider_options);
#else
ORT_THROW("QNN is not supported in this build\n");
#endif
} else if (!provider_name_.empty()) {
ORT_THROW("This execution provider is not included in this tool.\n");
}

for (auto model_path : test_config.model_file_paths) {
std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl;
InferenceSession session_object{so, ((OrtEnv*)*ort_env.get())->GetEnvironment()};
CheckStatus(session_object.RegisterExecutionProvider(qnn_ep));
CheckStatus(session_object.Load(ToPathString(model_path)));
CheckStatus(session_object.Initialize());
Ort::Session session(env, model_path.c_str(), so);
}
}

@@ -213,7 +200,7 @@ int real_main(int argc, char* argv[]) {

// Update generated context cache Onnx model to make the main EPContext node point to
// the last QNN context binary file
// Remove not used QNN context binary file, only keep the last one which contains all graphs
// Remove unused QNN context binary files; keep only the last one, which contains all graphs
UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size);
}
ORT_CATCH(const Ort::Exception& e) {