Enable QNN HTP spill fill buffer setting to save RAM usage. #22853

Open · wants to merge 3 commits into base: main
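This PR registers the multiple HTP contexts deserialized from one EPContext model into a single spill-fill group, so they share one spill-fill buffer instead of each allocating its own; that sharing is where the RAM saving comes from, and it only applies to context binaries generated with QNN v2.28 or later. As orientation before the diff, a minimal sketch of the registration flow, assembled from the QNN types this change uses. It is not the exact ORT code; the include paths and the helper name are assumptions.

#include <cstdint>
#include "QnnContext.h"         // assumed core header for QnnContext_Config_t
#include "HTP/QnnHtpContext.h"  // assumed header for the HTP custom-config types

// Fill in the two configs that register one deserialized context into a
// spill-fill group. Pass first_group_handle = 0x0 for the first context
// (starts a new group) or the first context's handle for later contexts.
// max_spill_fill_size is the largest spill-fill buffer size across the
// group's contexts and must be > 0.
static void SetSpillFillConfig(Qnn_ContextHandle_t first_group_handle,
                               uint64_t max_spill_fill_size,
                               QnnHtpContext_CustomConfig_t& custom_config,
                               QnnContext_Config_t& spill_fill_config) {
  custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
  QnnHtpContext_GroupRegistration_t group_info;
  group_info.firstGroupHandle = first_group_handle;
  group_info.maxSpillFillBuffer = max_spill_fill_size;
  custom_config.groupRegistration = group_info;  // copied by value into the union
  spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
  spill_fill_config.customConfig = &custom_config;
}

// The caller then builds a null-terminated config list and hands it to
// contextCreateFromBinary for each cached context, e.g.:
//   const QnnContext_Config_t* configs[] = {&spill_fill_config, nullptr};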
14 changes: 9 additions & 5 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models) {
QnnModelLookupTable& qnn_models,
uint32_t total_context_size) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
main_context_node.Name(),
qnn_models);
qnn_models,
total_context_size);
}

std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
@@ -145,17 +147,19 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
main_context_node.Name(),
qnn_models);
qnn_models,
total_context_size);
}

Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger) {
const logging::Logger& logger,
uint32_t total_context_size) {
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should have only one EPContext node!");
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
qnn_models);
qnn_models, total_context_size);

// This is the protocol with the customer: a status with INVALID_GRAPH is generated if the context model fails to load
if (!status.IsOK()) {
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -49,13 +49,15 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models);
QnnModelLookupTable& qnn_models,
uint32_t total_context_size);

Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger);
const logging::Logger& logger,
uint32_t total_context_size);

Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
69 changes: 46 additions & 23 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -8,6 +8,7 @@
#include <string>
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
#include "HTP/QnnHtpSystemContext.h"
#include "CPU/QnnCpuCommon.h"
// TODO: does not exist for Windows yet
// #include "GPU/QnnGpuCommon.h"
@@ -531,11 +532,11 @@ Status QnnBackendManager::CreateContext() {
}

QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
QnnHtpContext_CustomConfig_t customConfig;
customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
QnnHtpContext_CustomConfig_t custom_config;
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
context_config_weight_sharing.customConfig = &customConfig;
context_config_weight_sharing.customConfig = &custom_config;

QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
@@ -616,7 +617,8 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint64_t

Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
QnnModelLookupTable& qnn_models) {
QnnModelLookupTable& qnn_models,
uint32_t total_context_size) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
@@ -657,13 +659,48 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;

ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");
// HTP spill fill buffer only works for multiple QNN contexts generated after QNN v2.28
if (total_context_size > 1 && max_spill_fill_buffer_ == 0) {
for (uint32_t i = 0; i < graph_count; ++i) {
if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
auto htp_graph_info = reinterpret_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(graphs_info[i].graphInfoV3.graphBlobInfo);
if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize;
max_spill_fill_buffer_ = spill_fill_buffer_size > max_spill_fill_buffer_ ? spill_fill_buffer_size : max_spill_fill_buffer_;
} else {
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version.";
}
} else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 ||
graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2.";
} else {
LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version.";
}
}
}

QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};

// Register spill fill buffer for multi context
QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
QnnHtpContext_CustomConfig_t custom_config;
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
QnnHtpContext_GroupRegistration_t group_info;
size_t current_contexts_size = GetQnnContextSize();
// set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
group_info.firstGroupHandle = current_contexts_size > 0 ? GetQnnContext(0) : 0x0;
group_info.maxSpillFillBuffer = max_spill_fill_buffer_; // Max spill-fill buffer across contexts. Must be >0
custom_config.groupRegistration = group_info;
spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
spill_fill_config.customConfig = &custom_config;
QnnContext_Config_t* spill_fill_config_pointer =
(total_context_size > 1 && max_spill_fill_buffer_ > 0) ? &spill_fill_config : nullptr;

const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};

ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");
Qnn_ContextHandle_t context = nullptr;
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
@@ -672,7 +709,7 @@
buffer_length,
&context,
profile_backend_handle_);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
contexts_.push_back(context);
if (1 == graph_count) {
// in case the EPContext node is generated from script
@@ -932,20 +969,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id,
return Status::OK();
}

void QnnBackendManager::Split(std::vector<std::string>& split_string,
const std::string& tokenized_string,
const char separator) {
split_string.clear();
std::istringstream tokenized_string_stream(tokenized_string);
while (!tokenized_string_stream.eof()) {
std::string value;
getline(tokenized_string_stream, value, separator);
if (!value.empty()) {
split_string.push_back(value);
}
}
}

Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
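To make the group-handle chaining in LoadCachedQnnContextFromBuffer concrete, here is a hypothetical trace for a model with three main EPContext nodes (handles c0, c1, c2 are illustrative):

// call 1: contexts_ is empty   -> firstGroupHandle = 0x0 (starts the spill-fill group)
// call 2: contexts_ = {c0}     -> firstGroupHandle = GetQnnContext(0), i.e. c0
// call 3: contexts_ = {c0, c1} -> firstGroupHandle = GetQnnContext(0), i.e. c0
// All three contexts end up sharing one spill-fill allocation sized by
// max_spill_fill_buffer_, the maximum scanned from the binaries' graph info.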
10 changes: 7 additions & 3 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -93,7 +93,8 @@ class QnnBackendManager {

Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
uint32_t total_context_size);

Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);

@@ -112,6 +113,10 @@
return contexts_[index];
}

size_t GetQnnContextSize() {
return contexts_.size();
}

const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; }

const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
@@ -145,8 +150,6 @@

void ReleaseResources();

void Split(std::vector<std::string>& split_string, const std::string& tokenized_string, const char separator);

Status ExtractBackendProfilingInfo();
Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
@@ -268,6 +271,7 @@
QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE;
uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN;
bool enable_htp_weight_sharing_ = false;
uint64_t max_spill_fill_buffer_ = 0;
};

} // namespace qnn
15 changes: 9 additions & 6 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -363,18 +363,19 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map,
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
}

bool enable_htp_weight_sharing = false;
static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing";
auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED);
if (htp_weight_sharing_enabled_pos != provider_options_map.end()) {
if ("1" == htp_weight_sharing_enabled_pos->second) {
enable_htp_weight_sharing_ = true;
enable_htp_weight_sharing = true;
} else if ("0" == htp_weight_sharing_enabled_pos->second) {
enable_htp_weight_sharing_ = false;
enable_htp_weight_sharing = false;
} else {
LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_
LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing
<< " only 0 or 1 allowed. Set to 0.";
}
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_;
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing;
}

model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false,
@@ -396,7 +397,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map,
device_id_,
htp_arch,
soc_model,
enable_htp_weight_sharing_);
enable_htp_weight_sharing);

#ifdef _WIN32
auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance();
@@ -934,6 +935,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,

std::vector<int> main_context_pos_list;
ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list));
uint32_t total_context_size = SafeInt<uint32_t>(main_context_pos_list.size());

for (auto main_context_pos : main_context_pos_list) {
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
@@ -942,7 +944,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
context_cache_path,
qnn_backend_manager_.get(),
qnn_models,
logger));
logger,
total_context_size));
}

for (auto fused_node_and_graph : fused_nodes_and_graphs) {
1 change: 0 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider {
std::string context_node_name_prefix_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
bool qnn_context_embed_mode_ = true;
bool enable_htp_weight_sharing_ = false;
int32_t vtcm_size_in_mb_ = 0;
std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
ModelMetadefIdGenerator metadef_id_generator_;
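For completeness, a hedged sketch of exercising this path from the public ONNX Runtime C++ API: loading a pre-generated QNN context-binary model whose graph holds several EPContext main nodes. The model file name and option values are illustrative; the spill-fill grouping needs no extra provider option in this change, since it keys off the number of EPContext nodes and the sizes recorded in the cached binaries.

#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_spill_fill_demo");
  Ort::SessionOptions session_options;

  // Illustrative QNN EP options; backend_path selects the HTP backend.
  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "QnnHtp.dll"}};
  session_options.AppendExecutionProvider("QNN", qnn_options);

  // Hypothetical context-binary model with multiple EPContext main nodes;
  // with this change, their HTP contexts share one spill-fill buffer.
  Ort::Session session(env, ORT_TSTR("model_ctx.onnx"), session_options);
  return 0;
}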