diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 111f8e0a5fc34..e2ba98c80026f 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -667,23 +667,28 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide } // Validate the ep_context_path to make sure it is file path and check whether the file exist already -static Status EpContextFilePathCheck(const std::string& ep_context_path, - const std::filesystem::path& model_path) { - std::filesystem::path context_cache_path; +static Status EpContextFilePathCheckOrGet(const std::filesystem::path& ep_context_path, + const std::filesystem::path& model_path, + std::filesystem::path& context_cache_path) { if (!ep_context_path.empty()) { context_cache_path = ep_context_path; if (!context_cache_path.has_filename()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); } } else if (!model_path.empty()) { - context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + auto pos = model_path.native().find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + context_cache_path = model_path.native().substr(0, pos) + ORT_TSTR("_ctx.onnx"); + } else { + context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); } if (std::filesystem::exists(context_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already."); + context_cache_path, "' exist already. Please remove the EP context model if you want to re-generate it."); } return Status::OK(); @@ -714,15 +719,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers }; std::filesystem::path context_cache_path; - const std::filesystem::path& model_path = graph.ModelPath(); - - if (!ep_context_path.empty()) { - context_cache_path = ep_context_path; - } else if (!model_path.empty()) { - context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx"); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty"); - } + ORT_RETURN_IF_ERROR(EpContextFilePathCheckOrGet(ep_context_path, graph.ModelPath(), context_cache_path)); Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path @@ -1068,7 +1065,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (ep_context_enabled) { std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); // Check before EP compile graphs - ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath())); + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(EpContextFilePathCheckOrGet(ep_context_path, graph.ModelPath(), context_cache_path)); } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 3df231e53e7c0..d85277627a3de 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -198,35 +198,13 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, return Status::OK(); } -// Figure out the real context cache file path -// return true if context cache file exists -bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, - const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path) { - // always try the path set by user first, it's the only way to set it if load model from memory - if (!customer_context_cache_path.empty()) { - context_cache_path = ToPathString(customer_context_cache_path); - } else if (!model_pathstring.empty()) { // model loaded from file - if (is_qnn_ctx_model) { - // it's a context cache model, just use the model path - context_cache_path = model_pathstring; - } else if (!model_pathstring.empty()) { - // this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); - } - } - - return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); -} - Status CreateEPContextNodes(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const QnnModelLookupTable& qnn_models, - const onnxruntime::PathString& context_cache_path, + const onnxruntime::PathString& context_model_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger) { @@ -262,7 +240,19 @@ Status CreateEPContextNodes(Model* model, std::string cache_payload(buffer, buffer + buffer_size); ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload); } else { - onnxruntime::PathString context_bin_path = context_cache_path + ToPathString("_" + graph_name + ".bin"); + onnxruntime::PathString context_bin_path; + auto pos = context_model_path.find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + context_bin_path = context_model_path.substr(0, pos); + } else { + context_bin_path = context_model_path; + } + std::string graph_name_in_file(graph_name); + auto name_pos = graph_name_in_file.find_first_of(kQnnExecutionProvider); + if (name_pos != std::string::npos) { + graph_name_in_file.replace(name_pos, strlen(kQnnExecutionProvider), ""); + } + context_bin_path = context_bin_path + ToPathString(graph_name_in_file + ".bin"); std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string()); std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary); if (!of_stream) { diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 3dfa0ae21001b..c54cd3ca6e90c 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -38,11 +38,6 @@ Status CreateNodeArgs(const std::vector& names, std::vector& node_args, onnxruntime::Graph& graph); -bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, - const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path); - Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, @@ -67,7 +62,7 @@ Status CreateEPContextNodes(Model* model, const std::string& sdk_build_version, const std::vector& fused_nodes_and_graphs, const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_cache_path, + const onnxruntime::PathString& context_model_path, bool qnn_context_embed_mode, uint64_t max_spill_fill_buffer_size, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3fc537066ae0b..99a6f51f6f712 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -904,25 +904,33 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); - onnxruntime::PathString context_cache_path; + onnxruntime::PathString context_model_path; bool is_ctx_file_exist = false; if (is_qnn_ctx_model || context_cache_enabled_) { const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); - is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model, - context_cache_path_cfg_, - graph_viewer_0.ModelPath().native(), - context_cache_path); + // Figure out the EP context model path from model path or session option + GetContextOnnxModelFilePath(context_cache_path_cfg_, + graph_viewer_0.ModelPath().native(), + context_model_path); } - ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_, - "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ", - "Please remove the EP context model manually if you want to re-generate it."); - if (is_qnn_ctx_model) { // Get QnnModel from EP shared contexts if (share_ep_contexts_ && SharedContext::GetInstance().HasSharedQnnModels()) { @@ -965,7 +973,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); // Create QNN context from the cached binary, deserialize the QNN graph from the binary ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, - context_cache_path, + context_model_path, qnn_backend_manager_.get(), qnn_models, logger, @@ -1025,7 +1033,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_backend_manager_->GetSdkVersion(), fused_nodes_and_graphs, qnn_models_, - context_cache_path, + context_model_path, qnn_context_embed_mode_, max_spill_fill_buffer_size, logger)); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 07843c30a61df..e50dd7c214240 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -333,7 +333,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); const std::string ep_context_onnx_file = "./ep_context_no_over_write.onnx"; - const std::string ep_context_binary_file = "./ep_context_no_over_write.onnx_QNNExecutionProvider_QNN_10880527342279992768_1_0.bin"; + const std::string ep_context_binary_file = "./ep_context_no_over_write_QNN_10880527342279992768_1_0.bin"; std::remove(ep_context_onnx_file.c_str()); Ort::SessionOptions so; @@ -580,6 +580,8 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run directly loads and run from Qnn context cache model + std::unordered_map session_option_pairs2; + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, @@ -587,7 +589,8 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_binary_file, + session_option_pairs2); // Clean up ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } @@ -604,7 +607,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx"; - std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); @@ -686,7 +689,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { #endif provider_options["offload_graph_io_quantization"] = "0"; const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; std::remove(context_binary_file.c_str()); std::remove(context_bin.string().c_str()); @@ -828,6 +831,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./qnn_context_not_exist.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -841,7 +845,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -854,6 +858,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "./test_ctx.onnx")); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -867,7 +872,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { #endif provider_options["offload_graph_io_quantization"] = "0"; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); @@ -911,6 +916,8 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run directly loads and run from Qnn context cache model + std::unordered_map session_option_pairs2; + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), provider_options, @@ -918,7 +925,8 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file); + context_binary_file, + session_option_pairs2); // Clean up ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } @@ -936,14 +944,14 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed_QNN_8283143575221199085_1_0.bin"; + std::remove(context_model_file.c_str()); std::remove(context_bin.string().c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); @@ -962,7 +970,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName session_option_pairs); // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // Check the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_bin)); @@ -990,18 +998,19 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName SessionOptions so; so.session_logid = "qnn_ctx_model_logger"; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str())); RunOptions run_options; run_options.run_tag = so.session_logid; InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options, &so))); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); // Verify the return status with code INVALID_GRAPH ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(context_model_file.c_str()), 0); ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); } @@ -1167,7 +1176,13 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - ctx_model_paths.push_back(model_path + "_ctx.onnx"); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); @@ -1265,7 +1280,13 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { for (auto model_path : onnx_model_paths) { CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); - ctx_model_paths.push_back(model_path + "_ctx.onnx"); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); } DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]);