diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td index 79c3bd46f1984..669dfd3cca7c6 100644 --- a/offload/liboffload/API/Common.td +++ b/offload/liboffload/API/Common.td @@ -176,7 +176,7 @@ def : Function { let desc = "Release the resources in use by Offload"; let details = [ "This decrements an internal reference count. When this reaches 0, all resources will be released", - "Subsequent API calls made after this are not valid" + "Subsequent API calls to methods other than `olInit` made after resources are released will return OL_ERRC_UNINITIALIZED" ]; let params = []; let returns = []; diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 6adebb25a2db0..9cf52171181f4 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -97,6 +97,7 @@ struct AllocInfo { // Global shared state for liboffload struct OffloadContext; static OffloadContext *OffloadContextVal; +std::mutex OffloadContextValMutex; struct OffloadContext { OffloadContext(OffloadContext &) = delete; OffloadContext(OffloadContext &&) = delete; @@ -107,6 +108,7 @@ struct OffloadContext { bool ValidationEnabled = true; DenseMap AllocInfoMap{}; SmallVector Platforms{}; + size_t RefCount; ol_device_handle_t HostDevice() { // The host platform is always inserted last @@ -146,19 +148,19 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { #include "Shared/Targets.def" Error initPlugins() { - auto *Context = new OffloadContext{}; + auto &Context = OffloadContext::get(); // Attempt to create an instance of each supported plugin. #define PLUGIN_TARGET(Name) \ do { \ - Context->Platforms.emplace_back(ol_platform_impl_t{ \ + Context.Platforms.emplace_back(ol_platform_impl_t{ \ std::unique_ptr(createPlugin_##Name()), \ pluginNameToBackend(#Name)}); \ } while (false); #include "Shared/Targets.def" // Preemptively initialize all devices in the plugin - for (auto &Platform : Context->Platforms) { + for (auto &Platform : Context.Platforms) { // Do not use the host plugin - it isn't supported. if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN) continue; @@ -178,31 +180,54 @@ Error initPlugins() { } // Add the special host device - auto &HostPlatform = Context->Platforms.emplace_back( + auto &HostPlatform = Context.Platforms.emplace_back( ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{}); - Context->HostDevice()->Platform = &HostPlatform; + Context.HostDevice()->Platform = &HostPlatform; - Context->TracingEnabled = std::getenv("OFFLOAD_TRACE"); - Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); - - OffloadContextVal = Context; + Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); + Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); return Plugin::success(); } -// TODO: We can properly reference count here and manage the resources in a more -// clever way Error olInit_impl() { - static std::once_flag InitFlag; - std::optional InitResult{}; - std::call_once(InitFlag, [&] { InitResult = initPlugins(); }); + std::lock_guard Lock{OffloadContextValMutex}; + + std::optional InitResult; + if (!isOffloadInitialized()) { + OffloadContextVal = new OffloadContext{}; + InitResult = initPlugins(); + } + + OffloadContext::get().RefCount++; if (InitResult) return std::move(*InitResult); return Error::success(); } -Error olShutDown_impl() { return Error::success(); } + +Error olShutDown_impl() { + std::lock_guard Lock{OffloadContextValMutex}; + + if (--OffloadContext::get().RefCount != 0) + return Error::success(); + + llvm::Error Result = Error::success(); + + for (auto &P : OffloadContext::get().Platforms) { + // Host plugin is nullptr and has no deinit + if (!P.Plugin) + continue; + + if (auto Res = P.Plugin->deinit()) + Result = llvm::joinErrors(std::move(Result), std::move(Res)); + } + delete OffloadContextVal; + OffloadContextVal = nullptr; + + return Result; +} Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, diff --git a/offload/unittests/OffloadAPI/init/olInit.cpp b/offload/unittests/OffloadAPI/init/olInit.cpp index 8e27e77cd0fb5..508615152b4f1 100644 --- a/offload/unittests/OffloadAPI/init/olInit.cpp +++ b/offload/unittests/OffloadAPI/init/olInit.cpp @@ -15,8 +15,20 @@ struct olInitTest : ::testing::Test {}; +TEST_F(olInitTest, Success) { + ASSERT_SUCCESS(olInit()); + ASSERT_SUCCESS(olShutDown()); +} + TEST_F(olInitTest, Uninitialized) { ASSERT_ERROR(OL_ERRC_UNINITIALIZED, olIterateDevices( [](ol_device_handle_t, void *) { return false; }, nullptr)); } + +TEST_F(olInitTest, RepeatedInit) { + for (size_t I = 0; I < 10; I++) { + ASSERT_SUCCESS(olInit()); + ASSERT_SUCCESS(olShutDown()); + } +}