Skip to content

[Offload] Implement olShutDown #144055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion offload/liboffload/API/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def : Function {
let desc = "Release the resources in use by Offload";
let details = [
"This decrements an internal reference count. When this reaches 0, all resources will be released",
"Subsequent API calls made after this are not valid"
"Subsequent API calls to methods other than `olInit` made after resources are released will return OL_ERRC_UNINITIALIZED"
];
let params = [];
let returns = [];
Expand Down
55 changes: 40 additions & 15 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ struct AllocInfo {
// Global shared state for liboffload
struct OffloadContext;
static OffloadContext *OffloadContextVal;
std::mutex OffloadContextValMutex;
struct OffloadContext {
OffloadContext(OffloadContext &) = delete;
OffloadContext(OffloadContext &&) = delete;
Expand All @@ -107,6 +108,7 @@ struct OffloadContext {
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
size_t RefCount;

ol_device_handle_t HostDevice() {
// The host platform is always inserted last
Expand Down Expand Up @@ -146,19 +148,19 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
#include "Shared/Targets.def"

Error initPlugins() {
auto *Context = new OffloadContext{};
auto &Context = OffloadContext::get();

// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
Context->Platforms.emplace_back(ol_platform_impl_t{ \
Context.Platforms.emplace_back(ol_platform_impl_t{ \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
pluginNameToBackend(#Name)}); \
} while (false);
#include "Shared/Targets.def"

// Preemptively initialize all devices in the plugin
for (auto &Platform : Context->Platforms) {
for (auto &Platform : Context.Platforms) {
// Do not use the host plugin - it isn't supported.
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
continue;
Expand All @@ -178,31 +180,54 @@ Error initPlugins() {
}

// Add the special host device
auto &HostPlatform = Context->Platforms.emplace_back(
auto &HostPlatform = Context.Platforms.emplace_back(
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
Context->HostDevice()->Platform = &HostPlatform;
Context.HostDevice()->Platform = &HostPlatform;

Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");

OffloadContextVal = Context;
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");

return Plugin::success();
}

// TODO: We can properly reference count here and manage the resources in a more
// clever way
Error olInit_impl() {
static std::once_flag InitFlag;
std::optional<Error> InitResult{};
std::call_once(InitFlag, [&] { InitResult = initPlugins(); });
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};

std::optional<Error> InitResult;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A std::optional<Error> is weird here. The result of Error::success() is supposed to be the optional result.

if (!isOffloadInitialized()) {
OffloadContextVal = new OffloadContext{};
InitResult = initPlugins();
}

OffloadContext::get().RefCount++;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be calling new and delete on this if it's null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The context is new'd in initPlugins(), but perhaps it makes sense doing it in olInit_impl, so I'll move it here.


if (InitResult)
return std::move(*InitResult);
return Error::success();
}
Error olShutDown_impl() { return Error::success(); }

Error olShutDown_impl() {
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};

if (--OffloadContext::get().RefCount != 0)
return Error::success();

llvm::Error Result = Error::success();

for (auto &P : OffloadContext::get().Platforms) {
// Host plugin is nullptr and has no deinit
if (!P.Plugin)
continue;
Comment on lines +219 to +221
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to handle this more cleanly in the future.


if (auto Res = P.Plugin->deinit())
Result = llvm::joinErrors(std::move(Result), std::move(Res));
Comment on lines +223 to +224
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have multiple plugins active this will potentially drop a previous error and hit an assertion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should collect all the errors from each plugin into a list. But I don't think that is actually handled correctly when returning it from the C api.

}
delete OffloadContextVal;
OffloadContextVal = nullptr;

return Result;
}

Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,
Expand Down
12 changes: 12 additions & 0 deletions offload/unittests/OffloadAPI/init/olInit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,20 @@

struct olInitTest : ::testing::Test {};

TEST_F(olInitTest, Success) {
ASSERT_SUCCESS(olInit());
ASSERT_SUCCESS(olShutDown());
}

TEST_F(olInitTest, Uninitialized) {
ASSERT_ERROR(OL_ERRC_UNINITIALIZED,
olIterateDevices(
[](ol_device_handle_t, void *) { return false; }, nullptr));
}

TEST_F(olInitTest, RepeatedInit) {
for (size_t I = 0; I < 10; I++) {
ASSERT_SUCCESS(olInit());
ASSERT_SUCCESS(olShutDown());
}
}
Loading