diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc index 93c5ed54ab371..20918f8bc6de1 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.cc +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -2,44 +2,157 @@ // Licensed under the MIT License #include "core/providers/qnn/rpcmem_library.h" + +#if defined(_WIN32) +#include + +#include +#include +#endif // defined(_WIN32) + #include "core/providers/qnn/ort_api.h" namespace onnxruntime::qnn { +// Unload the dynamic library referenced by `library_handle`. +// Avoid throwing because this may run from a dtor. +void DynamicLibraryHandleDeleter::operator()(void* library_handle) noexcept { + if (library_handle == nullptr) { + return; + } + + const auto& env = GetDefaultEnv(); + const auto unload_status = env.UnloadDynamicLibrary(library_handle); + + if (!unload_status.IsOK()) { + LOGS_DEFAULT(WARNING) << "Failed to unload dynamic library. Error: " << unload_status.ErrorMessage(); + } +} + namespace { -const PathChar* GetRpcMemSharedLibraryPath() { #if defined(_WIN32) - return ORT_TSTR("libcdsprpc.dll"); -#else - return ORT_TSTR("libcdsprpc.so"); -#endif + +struct ServiceHandleDeleter { + void operator()(SC_HANDLE handle) { ::CloseServiceHandle(handle); } +}; + +using UniqueServiceHandle = std::unique_ptr, ServiceHandleDeleter>; + +Status ReadEnvironmentVariable(const wchar_t* name, std::wstring& value_out) { + const DWORD value_size = ::GetEnvironmentVariableW(name, nullptr, 0); + ORT_RETURN_IF(value_size == 0, + "Failed to get environment variable length. GetEnvironmentVariableW error: ", ::GetLastError()); + + std::vector value(value_size); + + ORT_RETURN_IF(::GetEnvironmentVariableW(name, value.data(), value_size) == 0, + "Failed to get environment variable value. GetEnvironmentVariableW error: ", ::GetLastError()); + + value_out = std::wstring{value.data()}; + return Status::OK(); } -DynamicLibraryHandle LoadDynamicLibrary(const PathString& path, bool global_symbols) { - // Custom deleter to unload the shared library. Avoid throwing from it because it may run in dtor. - const auto unload_library = [](void* library_handle) { - if (library_handle == nullptr) { - return; - } +Status GetServiceBinaryDirectoryPath(const wchar_t* service_name, + std::filesystem::path& service_binary_directory_path_out) { + SC_HANDLE scm_handle_raw = ::OpenSCManagerW(nullptr, // local computer + nullptr, // SERVICES_ACTIVE_DATABASE + STANDARD_RIGHTS_READ); + ORT_RETURN_IF(scm_handle_raw == nullptr, + "Failed to open handle to service control manager. OpenSCManagerW error: ", ::GetLastError()); + + auto scm_handle = UniqueServiceHandle{scm_handle_raw}; + + SC_HANDLE service_handle_raw = ::OpenServiceW(scm_handle.get(), + service_name, + SERVICE_QUERY_CONFIG); + ORT_RETURN_IF(service_handle_raw == nullptr, + "Failed to open service handle. OpenServiceW error: ", ::GetLastError()); + + auto service_handle = UniqueServiceHandle{service_handle_raw}; + + // get service config required buffer size + DWORD service_config_buffer_size{}; + ORT_RETURN_IF(!::QueryServiceConfigW(service_handle.get(), nullptr, 0, &service_config_buffer_size) && + ::GetLastError() != ERROR_INSUFFICIENT_BUFFER, + "Failed to query service configuration buffer size. QueryServiceConfigW error: ", ::GetLastError()); - const auto& env = GetDefaultEnv(); - const auto unload_status = env.UnloadDynamicLibrary(library_handle); + // get the service config + std::vector service_config_buffer(service_config_buffer_size); + QUERY_SERVICE_CONFIGW* service_config = reinterpret_cast(service_config_buffer.data()); + ORT_RETURN_IF(!::QueryServiceConfigW(service_handle.get(), service_config, service_config_buffer_size, + &service_config_buffer_size), + "Failed to query service configuration. QueryServiceConfigW error: ", ::GetLastError()); - if (!unload_status.IsOK()) { - LOGS_DEFAULT(WARNING) << "Failed to unload shared library. Error: " << unload_status.ErrorMessage(); - } - }; + std::wstring service_binary_path_name = service_config->lpBinaryPathName; + // replace system root placeholder with the value of the SYSTEMROOT environment variable + const std::wstring system_root_placeholder = L"\\SystemRoot"; + + ORT_RETURN_IF(service_binary_path_name.find(system_root_placeholder, 0) != 0, + "Service binary path '", ToUTF8String(service_binary_path_name), + "' does not start with expected system root placeholder value '", + ToUTF8String(system_root_placeholder), "'."); + + std::wstring system_root{}; + ORT_RETURN_IF_ERROR(ReadEnvironmentVariable(L"SYSTEMROOT", system_root)); + service_binary_path_name.replace(0, system_root_placeholder.size(), system_root); + + const auto service_binary_path = std::filesystem::path{service_binary_path_name}; + auto service_binary_directory_path = service_binary_path.parent_path(); + + ORT_RETURN_IF(!std::filesystem::exists(service_binary_directory_path), + "Service binary directory path does not exist: ", service_binary_directory_path.string()); + + service_binary_directory_path_out = std::move(service_binary_directory_path); + return Status::OK(); +} + +#endif // defined(_WIN32) + +Status GetRpcMemDynamicLibraryPath(PathString& path_out) { +#if defined(_WIN32) + + std::filesystem::path qcnspmcdm_dir_path{}; + ORT_RETURN_IF_ERROR(GetServiceBinaryDirectoryPath(L"qcnspmcdm", qcnspmcdm_dir_path)); + const auto libcdsprpc_path = qcnspmcdm_dir_path / L"libcdsprpc.dll"; + path_out = libcdsprpc_path.wstring(); + return Status::OK(); + +#else // ^^^ defined(_WIN32) / vvv !defined(_WIN32) + + path_out = ORT_TSTR("libcdsprpc.so"); + return Status::OK(); + +#endif // !defined(_WIN32) +} + +Status LoadDynamicLibrary(const PathString& path, bool global_symbols, + UniqueDynamicLibraryHandle& library_handle_out) { const auto& env = GetDefaultEnv(); - void* library_handle = nullptr; + void* library_handle_raw = nullptr; + ORT_RETURN_IF_ERROR(env.LoadDynamicLibrary(path, global_symbols, &library_handle_raw)); + + library_handle_out = UniqueDynamicLibraryHandle{library_handle_raw}; + return Status::OK(); +} + +UniqueDynamicLibraryHandle GetRpcMemDynamicLibraryHandle() { + std::string_view error_message_prefix = "Failed to initialize RPCMEM dynamic library handle: "; + + PathString rpcmem_library_path{}; + auto status = GetRpcMemDynamicLibraryPath(rpcmem_library_path); + if (!status.IsOK()) { + ORT_THROW(error_message_prefix, status.ErrorMessage()); + } - const auto load_status = env.LoadDynamicLibrary(path, global_symbols, &library_handle); - if (!load_status.IsOK()) { - ORT_THROW("Failed to load ", ToUTF8String(path), ": ", load_status.ErrorMessage()); + UniqueDynamicLibraryHandle library_handle{}; + status = LoadDynamicLibrary(rpcmem_library_path, /* global_symbols */ false, library_handle); + if (!status.IsOK()) { + ORT_THROW(error_message_prefix, status.ErrorMessage()); } - return DynamicLibraryHandle{library_handle, unload_library}; + return library_handle; } RpcMemApi CreateApi(void* library_handle) { @@ -58,7 +171,7 @@ RpcMemApi CreateApi(void* library_handle) { } // namespace RpcMemLibrary::RpcMemLibrary() - : library_handle_(LoadDynamicLibrary(GetRpcMemSharedLibraryPath(), /* global_symbols */ false)), + : library_handle_(GetRpcMemDynamicLibraryHandle()), api_{CreateApi(library_handle_.get())} { } diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.h b/onnxruntime/core/providers/qnn/rpcmem_library.h index 0642c96798188..2746e147373bb 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.h +++ b/onnxruntime/core/providers/qnn/rpcmem_library.h @@ -10,7 +10,11 @@ namespace onnxruntime::qnn { -using DynamicLibraryHandle = std::unique_ptr; +struct DynamicLibraryHandleDeleter { + void operator()(void* library_handle) noexcept; +}; + +using UniqueDynamicLibraryHandle = std::unique_ptr; // This namespace contains constants and typedefs corresponding to functions from rpcmem.h. // https://github.com/quic/fastrpc/blob/v0.1.1/inc/rpcmem.h @@ -61,7 +65,7 @@ class RpcMemLibrary { const RpcMemApi& Api() const { return api_; } private: - DynamicLibraryHandle library_handle_; + UniqueDynamicLibraryHandle library_handle_; RpcMemApi api_; }; diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 98d07fa06c009..0b51b6f8e503d 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1166,18 +1166,8 @@ TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { try { qnn_ep = QnnExecutionProviderWithOptions(provider_options); } catch (const OnnxRuntimeException& e) { - // handle particular exception that indicates that the libcdsprpc.so / dll can't be loaded - // NOTE: To run this on a local Windows ARM64 device, you need to copy libcdsprpc.dll to the build directory: - // - Open File Explorer - // - Go to C:/Windows/System32/DriverStore/FileRepository/ - // - Search for a folder that begins with qcnspmcdm8380.inf_arm64_ and open it - // - Copy the libcdsprpc.dll into the build/[PATH CONTAINING onnxruntime.dll] directory of the application. - // TODO(adrianlizarraga): Update CMake build for unittests to automatically copy libcdsprpc.dll into build directory -#if defined(_WIN32) - constexpr const char* expected_error_message = "Failed to load libcdsprpc.dll"; -#else - constexpr const char* expected_error_message = "Failed to load libcdsprpc.so"; -#endif + // handle exception that indicates that the libcdsprpc.so / dll can't be loaded + constexpr const char* expected_error_message = "Failed to initialize RPCMEM dynamic library handle"; ASSERT_THAT(e.what(), testing::HasSubstr(expected_error_message)); GTEST_SKIP() << "HTP shared memory allocator is unavailable."; } diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 59920487a7248..ca9ca0f82a25a 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1960,20 +1960,9 @@ static bool CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(PATH_TYPE model session = Ort::Session{*ort_env, model_path, session_options}; return true; } catch (const Ort::Exception& e) { - // handle particular exception that indicates that the libcdsprpc.so / dll can't be loaded - // NOTE: To run this on a local Windows ARM64 device, you need to copy libcdsprpc.dll to the build directory: - // - Open File Explorer - // - Go to C:/Windows/System32/DriverStore/FileRepository/ - // - Search for a folder that begins with qcnspmcdm8380.inf_arm64_ and open it - // - Copy the libcdsprpc.dll into the build/[PATH CONTAINING onnxruntime.dll] directory of the application. - // TODO(adrianlizarraga): Update CMake build for unittests to automatically copy libcdsprpc.dll into build directory + // handle exception that indicates that the libcdsprpc.so / dll can't be loaded std::string_view error_message = e.what(); - -#if defined(_WIN32) - std::string_view expected_error_message = "Failed to load libcdsprpc.dll"; -#else - std::string_view expected_error_message = "Failed to load libcdsprpc.so"; -#endif + std::string_view expected_error_message = "Failed to initialize RPCMEM dynamic library handle"; if (e.GetOrtErrorCode() == ORT_FAIL && error_message.find(expected_error_message) != std::string_view::npos) {