diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index aee6d2ff7655c..64b53c2912be0 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1029,7 +1029,7 @@ if (onnxruntime_USE_QNN) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - $ + $ $/onnxruntime/capi/ ) if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 896139a4e65a3..e5949d471c51c 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1282,31 +1282,29 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(onnxruntime_USE_QNN) #qnn ctx generator - set(onnxruntime_qnn_ctx_gen_src_dir ${TEST_SRC_DIR}/qnn_ctx_gen) - set(onnxruntime_qnn_ctx_gen_src_patterns - "${onnxruntime_qnn_ctx_gen_src_dir}/*.cc" - "${onnxruntime_qnn_ctx_gen_src_dir}/*.h") + set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) + set(ep_weight_sharing_ctx_gen_src_patterns + "${ep_weight_sharing_ctx_gen_src_dir}/*.cc" + "${ep_weight_sharing_ctx_gen_src_dir}/*.h") - file(GLOB onnxruntime_qnn_ctx_gen_src CONFIGURE_DEPENDS - ${onnxruntime_qnn_ctx_gen_src_patterns} + file(GLOB ep_weight_sharing_ctx_gen_src CONFIGURE_DEPENDS + ${ep_weight_sharing_ctx_gen_src_patterns} ) - onnxruntime_add_executable(onnxruntime_qnn_ctx_gen ${onnxruntime_qnn_ctx_gen_src}) - target_include_directories(onnxruntime_qnn_ctx_gen PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} - ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} - ${CMAKE_CURRENT_BINARY_DIR}) + onnxruntime_add_executable(ep_weight_sharing_ctx_gen ${ep_weight_sharing_ctx_gen_src}) + target_include_directories(ep_weight_sharing_ctx_gen PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) if (WIN32) - target_compile_options(onnxruntime_qnn_ctx_gen PRIVATE ${disabled_warnings}) + 
target_compile_options(ep_weight_sharing_ctx_gen PRIVATE ${disabled_warnings}) if (NOT DEFINED SYS_PATH_LIB) set(SYS_PATH_LIB shlwapi) endif() endif() if(WIN32) - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE debug dbghelp advapi32) endif() - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE onnxruntime_session ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) - set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(ep_weight_sharing_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") endif() # shared lib diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 0a00b56bf809d..8a96f59e9e790 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -470,8 +470,10 @@ Status QnnBackendManager::InitializeProfiling() { QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; if (ProfilingLevel::BASIC == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to basic."; } else if (ProfilingLevel::DETAILED == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed."; } Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, 
&profile_backend_handle_); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! Error: ", QnnErrorHandleToString(result)); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 8f0f8414f9b3b..0fb719cbb0bbe 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -384,17 +384,26 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } - qnn_backend_manager_ = qnn::QnnBackendManager::Create( - qnn::QnnBackendManagerConfig{backend_path, - profiling_level_etw, - profiling_level, - profiling_file_path, - context_priority, - qnn_saver_path, - device_id_, - htp_arch, - soc_model, - enable_htp_weight_sharing}); + // For context binary generation with weight sharing enabled, use the QnnBackendManager from the shared context if it exists, + // so that all graphs from later sessions will be compiled into the same QNN context. + if (context_cache_enabled_ && share_ep_contexts_ && SharedContext::GetInstance().GetSharedQnnBackendManager()) { + qnn_backend_manager_ = SharedContext::GetInstance().GetSharedQnnBackendManager(); + qnn_backend_manager_shared = true; + // Only share across 2 sessions: reset the shared manager once it is consumed so it is not passed on any further. + SharedContext::GetInstance().ResetSharedQnnBackendManager(); + } else { + qnn_backend_manager_ = qnn::QnnBackendManager::Create( + qnn::QnnBackendManagerConfig{backend_path, + profiling_level_etw, + profiling_level, + profiling_file_path, + context_priority, + qnn_saver_path, + device_id_, + htp_arch, + soc_model, + enable_htp_weight_sharing}); + } #if defined(_WIN32) if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) { @@ -1028,6 +1037,12 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_context_embed_mode_, max_spill_fill_buffer_size, logger)); + + if (share_ep_contexts_ && !qnn_backend_manager_shared && + 
nullptr == SharedContext::GetInstance().GetSharedQnnBackendManager()) { + ORT_RETURN_IF_NOT(SharedContext::GetInstance().SetSharedQnnBackendManager(qnn_backend_manager_), + "Failed to set shared QnnBackendManager."); + } } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 17cab3aabc3e4..42c65f8c58b47 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -77,6 +77,8 @@ class QNNExecutionProvider : public IExecutionProvider { std::shared_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; + // flag to know whether the shared QnnBackendManager is consumed. Stop passing it through if it is consumed. + bool qnn_backend_manager_shared = false; std::string context_cache_path_cfg_ = ""; std::string context_node_name_prefix_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. 
diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h index 81de357dbe677..277a484ad8528 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -61,13 +61,39 @@ class SharedContext { return graph_exist; } + bool SetSharedQnnBackendManager(std::shared_ptr& qnn_backend_manager) { + const std::lock_guard lock(mtx_); + + if (qnn_backend_manager_ != nullptr) { + if (qnn_backend_manager_ == qnn_backend_manager) { + return true; + } + return false; + } + qnn_backend_manager_ = qnn_backend_manager; + return true; + } + + std::shared_ptr GetSharedQnnBackendManager() { + const std::lock_guard lock(mtx_); + return qnn_backend_manager_; + } + + void ResetSharedQnnBackendManager() { + const std::lock_guard lock(mtx_); + qnn_backend_manager_.reset(); + } + private: SharedContext() = default; ~SharedContext() = default; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SharedContext); + // Used for passing through QNN models (deserialized from context binary) across sessions std::vector> shared_qnn_models_; + // Used for compiling multiple models into same QNN context binary + std::shared_ptr qnn_backend_manager_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized std::mutex mtx_; diff --git a/onnxruntime/test/qnn_ctx_gen/README.md b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md similarity index 82% rename from onnxruntime/test/qnn_ctx_gen/README.md rename to onnxruntime/test/ep_weight_sharing_ctx_gen/README.md index 97ab89d79cbd2..be1a1fe039366 100644 --- a/onnxruntime/test/qnn_ctx_gen/README.md +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md @@ -2,17 +2,19 @@ This tool provides the way to generate Onnx models that wraps QNN context binary warpt with weight sharing enabled. The options to use with the tool are listed below: -`onnxruntime_qnn_ctx_gen [options...] 
model_path,model_path` +`ep_weight_sharing_ctx_gen [options...] model_1_path,model_2_path` -./onnxruntime_qnn_ctx_gen -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" -C "ep.context_enable|1 ep.context_embed_mode|0" /mnt/c/model1.onnx,/mnt/c/model2.onnx +./ep_weight_sharing_ctx_gen -e qnn -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" /mnt/c/model1.onnx,/mnt/c/model2.onnx Options: - + + -e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider qnn, tensorrt, openvino, vitisai. Default is qnn. + -v: Show verbose information. -C: [session_config_entries]: Specify session configuration entries as key-value pairs: -C "| |" Refer to onnxruntime_session_options_config_keys.h for valid keys and values. - [Example] -C "ep.context_enable|1 ep.context_embed_mode|0" + [Example] -C "ep.context_enable|1 ep.context_embed_mode|0". These are set as default so can be ignored. -i: [provider_options]: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: [Usage]: -i '| |' diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc similarity index 68% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.cc rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc index 24c343c7b9541..bf21d54ccde41 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. #include "command_args_parser.h" @@ -29,28 +28,30 @@ namespace qnnctxgen { /*static*/ void CommandLineParser::ShowUsage() { printf( - "onnxruntime_qnn_ctx_gen [options...] 
model1_path,model2_path\n" - "Example: ./onnxruntime_qnn_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" + "ep_weight_sharing_ctx_gen [options...] model1_path,model2_path\n" + "Example: ./ep_weight_sharing_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" "Options:\n" + "\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn','tensorrt','openvino', 'vitisai'. " + "Default:'qnn'.\n" "\t-v: Show verbose information.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "\t Force ep.context_enable to 1 and ep.context_embed_mode to 0. Change ep.context_file_path is not allowed." "\t [Example] -C \"ep.context_node_name_prefix|_part1\" \n" - "\t-i: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [Usage]: -i '| |'\n" "\n" - "\t [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" - "\t [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" - "\t [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" - "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" - "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" - "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. 
\n" + "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" + "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" + "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" - "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" - "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" - "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" - "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." + "\t [QNN only] [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '1' (another EP handles the graph I/O quantization and dequantization). \n" + "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary.\n" 
"\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -109,8 +110,22 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("o:u:i:C:vh"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) { switch (ch) { + case 'e': + if (!CompareCString(optarg, ORT_TSTR("qnn"))) { + test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("openvino"))) { + test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) { + test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { + test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else { + fprintf(stderr, "The execution provider is not included in this tool.\n"); + return false; + } + break; case 'v': test_config.run_config.f_verbose = true; break; @@ -162,7 +177,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])"); } - test_config.run_config.qnn_options[key] = value; + test_config.run_config.provider_options[key] = value; } break; } diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h similarity index 100% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc new file mode 100644 index 0000000000000..3d58ff424f120 --- /dev/null +++ 
b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_configuration.h" +#include "command_args_parser.h" + +// onnxruntime dependencies +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +// onnx dependencies +#include "onnx/onnx_pb.h" +#include + +using namespace onnxruntime; +using ProviderOptions = std::unordered_map; + +// from the last context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models +// get the max spill fill buffer size +static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, + std::string& last_ctx_bin_file, + int64_t& max_size) { + max_size = 0; + + onnx::ModelProto model; + std::ifstream onnx_file_stream(last_onnx_ctx_file, std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + + for (auto& node : model.graph().node()) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + for (auto& attr : node.attribute()) { + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + max_size = attr.i(); + } + if (attr.name() == "ep_cache_context") { + last_ctx_bin_file = attr.s(); + } + } + if (is_main_context) { + return; + } + } + } + + onnx_file_stream.close(); +} + +// Update generated context cache Onnx model to make the main EPContext node point to +// the last QNN context binary file +// Remove not used QNN context binary file, only keep the last one which contains all graphs +static void UpdateEpContextModel(const std::vector>& ep_ctx_files, + const std::string& last_qnn_ctx_binary_file_name, + int64_t max_size) { + for (auto ep_ctx_file : ep_ctx_files) { + onnx::ModelProto model; + std::ifstream onnx_file_stream(ep_ctx_file, 
std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + onnx_file_stream.close(); + + for (auto& node : *(model.mutable_graph()->mutable_node())) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + std::string old_qnn_ctx_binary_file_name; + int max_size_index = -1; + int ep_context_index = -1; + for (auto i = 0; i < node.attribute_size(); ++i) { + auto& attr = node.attribute()[i]; + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + // keep the caller-supplied max_size (taken from the last context binary); only record where the attribute lives + max_size_index = i; + } + if (attr.name() == "ep_cache_context") { + old_qnn_ctx_binary_file_name = attr.s(); + ep_context_index = i; + } + } + if (is_main_context) { + auto path_str = ToPathString(ep_ctx_file); + auto path = std::filesystem::path(path_str); + auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); + std::remove(file_path.string().c_str()); + // an index of -1 means the attribute is absent on this node, so there is nothing to rewrite + if (max_size_index >= 0) node.mutable_attribute(max_size_index)->set_i(max_size); + if (ep_context_index >= 0) node.mutable_attribute(ep_context_index)->set_s(last_qnn_ctx_binary_file_name); + } + } + } + + // re-write the onnx ctx file + std::ofstream onnx_file_ostream(ep_ctx_file, std::ios_base::binary); + model.SerializeToOstream(&onnx_file_ostream); + onnx_file_ostream.close(); + } +} + +#ifdef _WIN32 +int real_main(int argc, wchar_t* argv[]) { +#else +int real_main(int argc, char* argv[]) { +#endif + qnnctxgen::TestConfig test_config; + if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { + qnnctxgen::CommandLineParser::ShowUsage(); + return -1; + } + + OrtLoggingLevel logging_level = test_config.run_config.f_verbose + ? 
ORT_LOGGING_LEVEL_VERBOSE + : ORT_LOGGING_LEVEL_ERROR; + Ort::Env env(logging_level, "ep_weight_sharing"); + + ORT_TRY { + Ort::SessionOptions so; + so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); + // Set default session option to dump EPContext model with non-embed mode + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + ProviderOptions provider_options; + + for (auto it : test_config.run_config.provider_options) { + provider_options[it.first] = it.second; + } + + for (auto it : test_config.run_config.session_config_entries) { + if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { + std::cerr << "Need to enable ep context cache." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { + std::cerr << "Only support non-embed model for weight sharing." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextFilePath) { + std::cout << "Not support to specify the generated Onnx context cache file name." 
<< std::endl; + continue; + } + so.AddConfigEntry(it.first.c_str(), it.second.c_str()); + } + + for (auto model_path : test_config.model_file_paths) { + std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; + } + + // Generate context cache model files with QNN context binary files + // The context binary file generated later includes all graphs from previous models + { + std::string provider_name_ = test_config.machine_config.provider_type_name; + if (provider_name_ == onnxruntime::kQnnExecutionProvider) { +#ifdef USE_QNN +#if defined(_WIN32) + provider_options.emplace("backend_path", "QnnHtp.dll"); +#else + provider_options.emplace("backend_path", "libQnnHtp.so"); +#endif + // emplace() keeps a backend_path supplied by the user via -i; the HTP library above is only the default + if (test_config.model_file_paths.size() > 2) { + std::cerr << "QNN EP only support 2 models for the weight sharing feature."; + return -1; + } + + // set default QNN EP option to enable weight sharing if not set by user + const std::string enable_htp_weight_sharing = "enable_htp_weight_sharing"; + if (provider_options.find(enable_htp_weight_sharing) == provider_options.end()) { + provider_options[enable_htp_weight_sharing] = "1"; + } + so.AppendExecutionProvider("QNN", provider_options); +#else + ORT_THROW("QNN is not supported in this build\n"); +#endif + } else if (!provider_name_.empty()) { + ORT_THROW("This execution provider is not included in this tool.\n"); + } + + for (auto model_path : test_config.model_file_paths) { + std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl; + Ort::Session session(env, model_path.c_str(), so); + } + } + + std::cout << "Start to update the generated Onnx model." 
<< std::endl; + std::vector> ep_ctx_files; + ep_ctx_files.reserve(test_config.model_file_paths.size()); + for (auto model_path : test_config.model_file_paths) { + ep_ctx_files.push_back(model_path + ORT_TSTR("_ctx.onnx")); + } + + // Get the last context binary file name + std::string last_qnn_ctx_binary_file_name; + int64_t max_size = 0; + GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); + std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; + if (last_qnn_ctx_binary_file_name.empty()) { + throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); + } + ep_ctx_files.pop_back(); + + // Update generated context cache Onnx model to make the main EPContext node point to + // the last QNN context binary file + // Remove not used QNN context binary file, only keep the last one only which contains all graphs + UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); + } + ORT_CATCH(const Ort::Exception& e) { + std::cerr << "Failed to generate context cache file: " << e.what(); + return -1; + } + + std::cout << "Generation done!"; + return 0; +} + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + int retval = -1; + ORT_TRY { + retval = real_main(argc, argv); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "%s\n", ex.what()); + retval = -1; + }); + } + + ::google::protobuf::ShutdownProtobufLibrary(); + + return retval; +} diff --git a/onnxruntime/test/qnn_ctx_gen/test_configuration.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h similarity index 75% rename from onnxruntime/test/qnn_ctx_gen/test_configuration.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h index bf4c7061a3484..198d03211f561 100644 --- a/onnxruntime/test/qnn_ctx_gen/test_configuration.h +++ 
b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h @@ -14,15 +14,20 @@ namespace onnxruntime { namespace qnnctxgen { +struct MachineConfig { + std::string provider_type_name{onnxruntime::kQnnExecutionProvider}; +}; + struct RunConfig { bool f_verbose{false}; std::unordered_map session_config_entries; - std::unordered_map qnn_options; + std::unordered_map provider_options; }; struct TestConfig { std::vector> model_file_paths; RunConfig run_config; + MachineConfig machine_config; }; } // namespace qnnctxgen diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 07843c30a61df..a7bb7953b3464 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -43,6 +43,35 @@ static const std::string& GetNodeAttr(const Node& node, const std::string& attr_ return default_val; } +// from the context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name +static void GetContextBinaryFileName(const std::string onnx_ctx_file, + std::string& last_ctx_bin_file, + const Logger& logger) { + std::shared_ptr ctx_model; + ASSERT_STATUS_OK(Model::Load(ToPathString(onnx_ctx_file), ctx_model, nullptr, logger)); + auto& ctx_graph = ctx_model->MainGraph(); + for (auto& node : ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); + if (1 == is_main_context) { + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); + return; + } + } + } +} + +// Get context binary file name from Context model file and remove it with the context model file +void CleanUpCtxFile(std::string context_file_path) { + std::string qnn_ctx_binary_file_name; + GetContextBinaryFileName(context_file_path, qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); + + 
ASSERT_EQ(std::remove(qnn_ctx_binary_file_name.c_str()), 0); + ASSERT_EQ(std::remove(context_file_path.c_str()), 0); +} + // Create a model with FusedMatMul + Add (quantized) // input1 -> Add -> Q -> DQ ---- // | @@ -123,22 +152,22 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_multi_partition_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); int ep_context_node_count = 0; int non_ep_context_node_count = 0; std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); auto& ctx_graph = ctx_model->MainGraph(); for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { @@ -156,7 +185,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::SessionOptions so2; // context file path is required if it's non-embed mode and the model is loaded from memory - so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, 
context_binary_file.c_str()); + so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so2.AppendExecutionProvider("QNN", provider_options); std::string ctx_model_data; @@ -164,7 +193,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary @@ -237,7 +266,7 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { // clean up ASSERT_EQ(std::remove(model_with_ext.c_str()), 0); ASSERT_EQ(std::remove(model_ext_file_full_path.c_str()), 0); - ASSERT_EQ(std::remove(ep_context_model_file.c_str()), 0); + CleanUpCtxFile(ep_context_model_file); } // Set the external initializer size threshold to 1024 so FusedMatMul (which fallback on CPU) @@ -444,21 +473,21 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + 
EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Generate context cache model from the ONNX models with 2 inputs. @@ -481,26 +510,26 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); auto inputs = model->MainGraph().GetInputs(); EXPECT_TRUE(inputs.size() == 2); EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { @@ -519,20 +548,20 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { auto& logging_manager = DefaultLoggingManager(); 
logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); for (auto& node : model->MainGraph().Nodes()) { if (node.OpType() == "EPContext") { EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); @@ -540,7 +569,7 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { } // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -554,12 +583,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_test.onnx"; + 
std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -577,7 +606,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), @@ -587,9 +616,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_model_file); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -884,12 +913,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_2inputs_test.onnx"; + std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef input_def1({1, 2, 3}, false, 
-10.0f, 10.0f); const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); @@ -908,7 +937,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), @@ -918,9 +947,9 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_model_file); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node @@ -936,14 +965,14 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + const std::string context_model_file = "./qnn_context_cache_non_embed.onnx"; std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - std::remove(context_binary_file.c_str()); + std::remove(context_model_file.c_str()); std::remove(context_bin.string().c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); const TestInputDef input_def({1, 2, 3}, false, 
-10.0f, 10.0f); @@ -962,7 +991,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName session_option_pairs); // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // Check the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_bin)); @@ -1001,7 +1030,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(context_model_file.c_str()), 0); ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); } @@ -1053,44 +1082,19 @@ static void CreateQdqModel(const std::string& model_file_name, const Logger& log static void DumpModelWithSharedCtx(const ProviderOptions& provider_options, const std::string& onnx_model_path1, const std::string& onnx_model_path2) { - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - auto qnn_ep = QnnExecutionProviderWithOptions(provider_options, &so); - std::shared_ptr qnn_ep_shared(std::move(qnn_ep)); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); - InferenceSessionWrapper session_object1{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object1.RegisterExecutionProvider(qnn_ep_shared)); - 
ASSERT_STATUS_OK(session_object1.Load(ToPathString(onnx_model_path1))); - ASSERT_STATUS_OK(session_object1.Initialize()); + so.AppendExecutionProvider("QNN", provider_options); - InferenceSessionWrapper session_object2{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(qnn_ep_shared)); - ASSERT_STATUS_OK(session_object2.Load(ToPathString(onnx_model_path2))); - ASSERT_STATUS_OK(session_object2.Initialize()); -} + // Create 2 sessions to generate context binary models, the 1st session will share the QnnBackendManager + // to the 2nd session, so graphs from these 2 models are all included in the 2nd context binary + Ort::Session session1(*ort_env, ToPathString(onnx_model_path1).c_str(), so); -// from the last context ache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, thie context binary contains all graphs from all Onnx models -static void GetLastContextBinaryFileName(const std::string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - const Logger& logger) { - std::shared_ptr ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, logger)); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } + Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } // Update generated context cache Onnx model to make the main EPContext node point to @@ -1172,10 +1176,10 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); - // Get the last context binary file name + // Get the last context binary file name, the latest context binary file holds all graphs generated from all 
models std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1272,8 +1276,8 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1336,6 +1340,61 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { } std::remove(last_qnn_ctx_binary_file_name.c_str()); } + +// For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled +// Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context +TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + // Create QDQ models + std::vector onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; + std::vector ctx_model_paths; + for (auto model_path : onnx_model_paths) { + CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); + 
ctx_model_paths.push_back(model_path + "_ctx.onnx"); + } + + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNNEP share the QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session1(*ort_env, ToPathString(onnx_model_paths[0]).c_str(), so); + std::string qnn_ctx_binary_file_name1; + GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); + + Ort::Session session2(*ort_env, ToPathString(onnx_model_paths[1]).c_str(), so); + std::string qnn_ctx_binary_file_name2; + GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); + + auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); + auto file_size_2 = std::filesystem::file_size(qnn_ctx_binary_file_name2); + EXPECT_TRUE(file_size_2 > file_size_1); + + // clean up + for (auto model_path : onnx_model_paths) { + ASSERT_EQ(std::remove(model_path.c_str()), 0); + } + for (auto ctx_model_path : ctx_model_paths) { + ASSERT_EQ(std::remove(ctx_model_path.c_str()), 0); + } + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name1.c_str()), 0); + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name2.c_str()), 0); +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc deleted file mode 100644 index bb5007b40b072..0000000000000 --- a/onnxruntime/test/qnn_ctx_gen/main.cc +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -// onnxruntime dependencies -#include "test_configuration.h" -#include -#include -#include -#include "command_args_parser.h" -#include - -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/inference_session.h" -#include "core/session/ort_env.h" -#include "core/providers/provider_factory_creators.h" -#include "core/common/logging/sinks/clog_sink.h" - -#include "core/graph/model.h" -#include "core/session/environment.h" -#include "core/common/logging/logging.h" - -using namespace onnxruntime; -const OrtApi* g_ort = NULL; -std::unique_ptr ort_env; - -static void CheckStatus(const Status& status) { - if (status.Code() != common::StatusCode::OK) { - std::string msg = status.ErrorMessage(); - throw Ort::Exception(std::move(msg), OrtErrorCode::ORT_FAIL); - } -} - -static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.i(); - } - - return default_val; -} - -static const std::string& GetNodeAttr(const Node& node, const std::string& attr_name, const std::string& default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.s(); - } - - return default_val; -} - -// from the last context cache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models -// get the max spill fill buffer size -static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - int64_t& max_size) { - max_size = 0; - std::shared_ptr ctx_model; - CheckStatus(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& 
ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - max_size = GetNodeAttr(node, "max_size", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } -} - -// Update generated context cache Onnx model to make the main EPContext node point to -// the last QNN context binary file -// Remove not used QNN context binary file, only keep the last one which contains all graphs -static void UpdateEpContextModel(const std::vector>& ep_ctx_files, - const std::string& last_qnn_ctx_binary_file_name, - int64_t max_size) { - for (auto ep_ctx_file : ep_ctx_files) { - std::shared_ptr ctx_model; - auto path_str = ToPathString(ep_ctx_file); - CheckStatus(Model::Load(path_str, ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& ctx_graph = ctx_model->MainGraph(); - GraphViewer graph_viewer(ctx_graph); - auto path = std::filesystem::path(path_str); - - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - std::string old_qnn_ctx_binary_file_name = GetNodeAttr(node, "ep_cache_context", ""); - auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); - std::remove(file_path.string().c_str()); - node.ClearAttribute("ep_cache_context"); - node.AddAttribute("ep_cache_context", last_qnn_ctx_binary_file_name); - node.ClearAttribute("max_size"); - node.AddAttribute("max_size", max_size); - } - } - } - std::remove(ToUTF8String(ep_ctx_file).c_str()); - CheckStatus(Model::Save(*ctx_model.get(), ToPathString(ep_ctx_file))); - } -} - -#ifdef _WIN32 -int real_main(int argc, wchar_t* argv[]) { -#else -int real_main(int argc, char* argv[]) { -#endif - g_ort = 
OrtGetApiBase()->GetApi(ORT_API_VERSION); - qnnctxgen::TestConfig test_config; - if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { - qnnctxgen::CommandLineParser::ShowUsage(); - return -1; - } - - { - bool failed = false; - ORT_TRY { - OrtLoggingLevel logging_level = test_config.run_config.f_verbose - ? ORT_LOGGING_LEVEL_VERBOSE - : ORT_LOGGING_LEVEL_WARNING; - - ort_env = std::make_unique(logging_level, "Default"); - } - ORT_CATCH(const Ort::Exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error creating environment. Error-> %s \n", e.what()); - failed = true; - }); - } - - if (failed) - return -1; - } - - ORT_TRY { - SessionOptions so; - so.session_logid = "qnn_ctx_gen_session_logger"; - // Set default session option to dump QNN context model with non-embed mode - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - // set default QNN EP option to enable weight sharing - provider_options["enable_htp_weight_sharing"] = "1"; - - for (auto it : test_config.run_config.qnn_options) { - provider_options[it.first] = it.second; - } - - for (auto it : test_config.run_config.session_config_entries) { - if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { - std::cerr << "Need to enable ep context cache." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { - std::cerr << "Only support non-embed model for weight sharing." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextFilePath) { - std::cout << "Not support to specify the generated Onnx context cache file name." 
<< std::endl; - continue; - } - CheckStatus(so.config_options.AddConfigEntry(it.first.c_str(), it.second.c_str())); - } - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; - } - - // Generate context cache model files with QNN context binary files - // The context binary file generated later includes all graphs from previous models - { - auto ep = QNNProviderFactoryCreator::Create(provider_options, &so)->CreateProvider(); - std::shared_ptr qnn_ep(std::move(ep)); - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl; - InferenceSession session_object{so, ((OrtEnv*)*ort_env.get())->GetEnvironment()}; - CheckStatus(session_object.RegisterExecutionProvider(qnn_ep)); - CheckStatus(session_object.Load(ToPathString(model_path))); - CheckStatus(session_object.Initialize()); - } - } - - std::cout << "Start to update the generated Onnx model." 
<< std::endl; - std::vector> ep_ctx_files; - ep_ctx_files.reserve(test_config.model_file_paths.size()); - for (auto model_path : test_config.model_file_paths) { - ep_ctx_files.push_back(model_path + ORT_TSTR("_ctx.onnx")); - } - - // Get the last context binary file name - std::string last_qnn_ctx_binary_file_name; - int64_t max_size = 0; - GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); - std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; - if (last_qnn_ctx_binary_file_name.empty()) { - throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); - } - ep_ctx_files.pop_back(); - - // Update generated context cache Onnx model to make the main EPContext node point to - // the last QNN context binary file - // Remove not used QNN context binary file, only keep the last one which contains all graphs - UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); - } - ORT_CATCH(const Ort::Exception& e) { - fprintf(stderr, "Failed to generate context cache file: %s \n", e.what()); - - ort_env.reset(); - return -1; - } - - ort_env.reset(); - - return 0; -} - -#ifdef _WIN32 -int wmain(int argc, wchar_t* argv[]) { -#else -int main(int argc, char* argv[]) { -#endif - int retval = -1; - ORT_TRY { - retval = real_main(argc, argv); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); - retval = -1; - }); - } - - ::google::protobuf::ShutdownProtobufLibrary(); - - return retval; -} diff --git a/setup.py b/setup.py index ced2f28e38778..53e533050b245 100644 --- a/setup.py +++ b/setup.py @@ -356,7 +356,7 @@ def finalize_options(self): "libQnnSaver.so", "libQnnSystem.so", "libHtpPrepare.so", - "onnxruntime_qnn_ctx_gen", + "ep_weight_sharing_ctx_gen", ] dl_libs.extend(qnn_deps) if nightly_build: