Change the logic to generate the default ep context file name #23788

Open · wants to merge 3 commits into main
Changes from 2 commits
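In effect, the default EP context model name is now derived by replacing the source model's extension rather than appending to the whole file name: `model.onnx` now yields `model_ctx.onnx` instead of `model.onnx_ctx.onnx`. A minimal sketch of the new rule, using plain `std::string` where the actual change below uses `std::filesystem::path` and `ORT_TSTR`:

```cpp
#include <string>

// Default EP context model name under the new logic:
// strip the last extension, then append "_ctx.onnx".
std::string DefaultEpContextPath(const std::string& model_path) {
  auto pos = model_path.find_last_of('.');
  if (pos != std::string::npos) {
    return model_path.substr(0, pos) + "_ctx.onnx";  // "model.onnx" -> "model_ctx.onnx"
  }
  return model_path + "_ctx.onnx";  // no extension: "model" -> "model_ctx.onnx"
}
```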
28 changes: 13 additions & 15 deletions onnxruntime/core/framework/graph_partitioner.cc
@@ -667,23 +667,28 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
 }

 // Validate the ep_context_path to make sure it is a file path and check whether the file already exists
-static Status EpContextFilePathCheck(const std::string& ep_context_path,
-                                     const std::filesystem::path& model_path) {
-  std::filesystem::path context_cache_path;
+static Status EpContextFilePathCheckOrGet(const std::filesystem::path& ep_context_path,
+                                          const std::filesystem::path& model_path,
+                                          std::filesystem::path& context_cache_path) {
   if (!ep_context_path.empty()) {
     context_cache_path = ep_context_path;
     if (!context_cache_path.has_filename()) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder.");
     }
   } else if (!model_path.empty()) {
-    context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
+    auto pos = model_path.native().find_last_of(ORT_TSTR("."));
+    if (pos != std::string::npos) {
+      context_cache_path = model_path.native().substr(0, pos) + ORT_TSTR("_ctx.onnx");
+    } else {
+      context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
+    }
   } else {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty.");
   }

   if (std::filesystem::exists(context_cache_path)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '",
-                           context_cache_path, "' exist already.");
+                           context_cache_path, "' exists already. Please remove the EP context model if you want to re-generate it.");
   }

   return Status::OK();
@@ -714,15 +719,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
   };

   std::filesystem::path context_cache_path;
-  const std::filesystem::path& model_path = graph.ModelPath();
-
-  if (!ep_context_path.empty()) {
-    context_cache_path = ep_context_path;
-  } else if (!model_path.empty()) {
-    context_cache_path = model_path.native() + ORT_TSTR("_ctx.onnx");
-  } else {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty");
-  }
+  ORT_RETURN_IF_ERROR(EpContextFilePathCheckOrGet(ep_context_path, graph.ModelPath(), context_cache_path));

   Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
                          graph.GetModel().ModelPath(),  // use source model path so that external initializers can find the data file path
@@ -1068,7 +1065,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
   if (ep_context_enabled) {
     std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
     // Check before the EPs compile the graphs
-    ORT_RETURN_IF_ERROR(EpContextFilePathCheck(ep_context_path, graph.ModelPath()));
+    std::filesystem::path context_cache_path;
+    ORT_RETURN_IF_ERROR(EpContextFilePathCheckOrGet(ep_context_path, graph.ModelPath(), context_cache_path));
   }

   // We use this only if Resource Aware Partitioning is enabled for any of the EPs
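With this refactor, one helper both validates a user-supplied path and derives the default one, so the pre-compile check in `Partition` and the writer in `CreateEpContextModel` share a single code path. A small self-contained check of the names the new rule should produce, reusing the sketch above; the last case shows that `find_last_of('.')` scans the whole string, so a dot in a directory name also matches when the file itself has no extension:

```cpp
#include <cassert>
#include <string>

// Same rule as the DefaultEpContextPath sketch shown earlier.
static std::string DefaultEpContextPath(const std::string& model_path) {
  auto pos = model_path.find_last_of('.');
  return pos != std::string::npos ? model_path.substr(0, pos) + "_ctx.onnx"
                                  : model_path + "_ctx.onnx";
}

int main() {
  assert(DefaultEpContextPath("mobilenet.onnx") == "mobilenet_ctx.onnx");
  assert(DefaultEpContextPath("mobilenet") == "mobilenet_ctx.onnx");
  // A dot in a directory name is matched when the file has no extension:
  assert(DefaultEpContextPath("models.v2/net") == "models_ctx.onnx");
  return 0;
}
```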
38 changes: 14 additions & 24 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -198,35 +198,13 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
   return Status::OK();
 }

-// Figure out the real context cache file path
-// return true if context cache file exists
-bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
-                                  const std::string& customer_context_cache_path,
-                                  const onnxruntime::PathString& model_pathstring,
-                                  onnxruntime::PathString& context_cache_path) {
-  // always try the path set by user first, it's the only way to set it if load model from memory
-  if (!customer_context_cache_path.empty()) {
-    context_cache_path = ToPathString(customer_context_cache_path);
-  } else if (!model_pathstring.empty()) {  // model loaded from file
-    if (is_qnn_ctx_model) {
-      // it's a context cache model, just use the model path
-      context_cache_path = model_pathstring;
-    } else if (!model_pathstring.empty()) {
-      // this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx
-      context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
-    }
-  }
-
-  return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
-}
-
 Status CreateEPContextNodes(Model* model,
                             unsigned char* buffer,
                             uint64_t buffer_size,
                             const std::string& sdk_build_version,
                             const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
                             const QnnModelLookupTable& qnn_models,
-                            const onnxruntime::PathString& context_cache_path,
+                            const onnxruntime::PathString& context_model_path,
                             bool qnn_context_embed_mode,
                             uint64_t max_spill_fill_buffer_size,
                             const logging::Logger& logger) {
@@ -262,7 +240,19 @@ Status CreateEPContextNodes(Model* model,
       std::string cache_payload(buffer, buffer + buffer_size);
       ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload);
     } else {
-      onnxruntime::PathString context_bin_path = context_cache_path + ToPathString("_" + graph_name + ".bin");
+      onnxruntime::PathString context_bin_path;
+      auto pos = context_model_path.find_last_of(ORT_TSTR("."));
+      if (pos != std::string::npos) {
+        context_bin_path = context_model_path.substr(0, pos);
+      } else {
+        context_bin_path = context_model_path;
+      }
+      std::string graph_name_in_file(graph_name);
+      auto name_pos = graph_name_in_file.find_first_of(kQnnExecutionProvider);
+      if (name_pos != std::string::npos) {
+        graph_name_in_file.replace(name_pos, strlen(kQnnExecutionProvider), "");
+      }
+      context_bin_path = context_bin_path + ToPathString(graph_name_in_file + ".bin");
       std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string());
       std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary);
       if (!of_stream) {
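In non-embed mode, each graph's context binary is now named after the EP context model (extension stripped) plus the graph name with the `QNNExecutionProvider` prefix removed, instead of being appended to the full model file name. A sketch of that derivation, using plain `std::string` instead of `PathString`; note it uses `find` to locate the substring where the diff above uses `find_first_of`, and the example graph name is hypothetical:

```cpp
#include <string>

std::string ContextBinPath(const std::string& context_model_path,
                           const std::string& graph_name) {
  const std::string kEpName = "QNNExecutionProvider";
  // Strip the context model's extension to get the base path.
  auto pos = context_model_path.find_last_of('.');
  std::string base = (pos != std::string::npos)
                         ? context_model_path.substr(0, pos)
                         : context_model_path;
  // Drop the EP name from the graph name so it doesn't repeat in the file name.
  std::string name_in_file = graph_name;
  auto name_pos = name_in_file.find(kEpName);
  if (name_pos != std::string::npos) {
    name_in_file.erase(name_pos, kEpName.size());
  }
  return base + name_in_file + ".bin";
}

// ContextBinPath("model_ctx.onnx", "QNNExecutionProvider_graph_0")
//   -> "model_ctx_graph_0.bin"
```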
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -38,11 +38,6 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
                       std::vector<NodeArg*>& node_args,
                       onnxruntime::Graph& graph);

-bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
-                                  const std::string& customer_context_cache_path,
-                                  const onnxruntime::PathString& model_pathstring,
-                                  onnxruntime::PathString& context_cache_path);
-
 Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
                                 const onnxruntime::PathString& ctx_onnx_model_path,
                                 QnnBackendManager* qnn_backend_manager,
@@ -67,7 +62,7 @@ Status CreateEPContextNodes(Model* model,
                             const std::string& sdk_build_version,
                             const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
                             const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
-                            const onnxruntime::PathString& context_cache_path,
+                            const onnxruntime::PathString& context_model_path,
                             bool qnn_context_embed_mode,
                             uint64_t max_spill_fill_buffer_size,
                             const logging::Logger& logger);
30 changes: 19 additions & 11 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -904,25 +904,33 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector<FusedNodeAndG
   return Status::OK();
 }

+// Figure out the context cache Onnx file path to decide the folder location
+static void GetContextOnnxModelFilePath(const std::string& customer_context_cache_path,
+                                        const onnxruntime::PathString& model_path_string,
+                                        onnxruntime::PathString& context_cache_binary_path) {
+  // always try the path set by the user first; it's the only way to set it if the model is loaded from memory
+  if (!customer_context_cache_path.empty()) {
+    context_cache_binary_path = ToPathString(customer_context_cache_path);
+  } else if (!model_path_string.empty()) {  // model loaded from file
+    context_cache_binary_path = model_path_string;
+  }
+}
+
 Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                      std::vector<NodeComputeInfo>& node_compute_funcs) {
   const auto& logger = *GetLogger();
   bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs);

-  onnxruntime::PathString context_cache_path;
-  bool is_ctx_file_exist = false;
+  onnxruntime::PathString context_model_path;
   if (is_qnn_ctx_model || context_cache_enabled_) {
     const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph);
-    is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model,
-                                                          context_cache_path_cfg_,
-                                                          graph_viewer_0.ModelPath().native(),
-                                                          context_cache_path);
+    // Figure out the EP context model path from the model path or the session option
+    GetContextOnnxModelFilePath(context_cache_path_cfg_,
+                                graph_viewer_0.ModelPath().native(),
+                                context_model_path);
   }

-  ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_,
-                "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ",
-                "Please remove the EP context model manually if you want to re-generate it.");

   if (is_qnn_ctx_model) {
     // Get QnnModel from EP shared contexts
     if (share_ep_contexts_ && SharedContext::GetInstance().HasSharedQnnModels()) {
@@ -965,7 +973,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
       const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
       // Create QNN context from the cached binary, deserialize the QNN graph from the binary
       ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
-                                                       context_cache_path,
+                                                       context_model_path,
                                                        qnn_backend_manager_.get(),
                                                        qnn_models,
                                                        logger,
@@ -1025,7 +1033,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
                                              qnn_backend_manager_->GetSdkVersion(),
                                              fused_nodes_and_graphs,
                                              qnn_models_,
-                                             context_cache_path,
+                                             context_model_path,
                                              qnn_context_embed_mode_,
                                              max_spill_fill_buffer_size,
                                              logger));
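For reference, EP context generation is driven entirely by session options; with this change, the default output for `model.onnx` becomes `model_ctx.onnx` next to the source model, and the session fails early if that file already exists. A minimal sketch using the public C++ API; the QNN backend library and file names are illustrative, and the QNN EP must be available in the build:

```cpp
#include <onnxruntime_cxx_api.h>

#include <string>
#include <unordered_map>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;

  // Enable EP context model generation. With no explicit file path,
  // the default output is now "model_ctx.onnx" (was "model.onnx_ctx.onnx").
  so.AddConfigEntry("ep.context_enable", "1");
  // Optional override; must be a file path, not a folder:
  // so.AddConfigEntry("ep.context_file_path", "compiled/model_ctx.onnx");

  std::unordered_map<std::string, std::string> qnn_options;
  qnn_options["backend_path"] = "QnnHtp.dll";  // illustrative backend library
  so.AppendExecutionProvider("QNN", qnn_options);

  // Creating the session compiles the graph and writes the EP context model.
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  return 0;
}
```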