diff --git a/unified-runtime/source/adapters/level_zero/adapter.cpp b/unified-runtime/source/adapters/level_zero/adapter.cpp index 6b23d0161a4f5..19958d148035b 100644 --- a/unified-runtime/source/adapters/level_zero/adapter.cpp +++ b/unified-runtime/source/adapters/level_zero/adapter.cpp @@ -296,7 +296,7 @@ Behavior Summary: SysMan initialization is skipped. */ ur_adapter_handle_t_::ur_adapter_handle_t_() - : handle_base(), logger(logger::get_logger("level_zero")) { + : handle_base(), logger(logger::get_logger("level_zero")), RefCount(0) { ZeInitDriversResult = ZE_RESULT_ERROR_UNINITIALIZED; ZeInitResult = ZE_RESULT_ERROR_UNINITIALIZED; ZesResult = ZE_RESULT_ERROR_UNINITIALIZED; @@ -675,7 +675,7 @@ ur_result_t urAdapterGet( } *Adapters = GlobalAdapter; - if (GlobalAdapter->RefCount++ == 0) { + if (GlobalAdapter->getRefCount().retain() == 0) { adapterStateInit(); } } @@ -692,7 +692,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) { // NOTE: This does not require guarding with a mutex; the instant the ref // count hits zero, both Get and Retain are UB. - if (--GlobalAdapter->RefCount == 0) { + if (GlobalAdapter->getRefCount().release()) { auto result = adapterStateTeardown(); #ifdef UR_STATIC_LEVEL_ZERO // Given static linking of the L0 Loader, we must delay the loader's @@ -711,7 +711,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) { ur_result_t urAdapterRetain([[maybe_unused]] ur_adapter_handle_t Adapter) { assert(GlobalAdapter && GlobalAdapter == Adapter); - GlobalAdapter->RefCount++; + GlobalAdapter->getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -740,7 +740,7 @@ ur_result_t urAdapterGetInfo(ur_adapter_handle_t, ur_adapter_info_t PropName, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_LEVEL_ZERO); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(GlobalAdapter->RefCount.load()); + return ReturnValue(GlobalAdapter->getRefCount().getCount()); case UR_ADAPTER_INFO_VERSION: { #ifdef UR_ADAPTER_LEVEL_ZERO_V2 uint32_t adapterVersion = 2; diff --git a/unified-runtime/source/adapters/level_zero/adapter.hpp b/unified-runtime/source/adapters/level_zero/adapter.hpp index bb0a9058bce1b..3549713c02068 100644 --- a/unified-runtime/source/adapters/level_zero/adapter.hpp +++ b/unified-runtime/source/adapters/level_zero/adapter.hpp @@ -9,9 +9,9 @@ //===----------------------------------------------------------------------===// #pragma once +#include "common/ur_ref_count.hpp" #include "logger/ur_logger.hpp" #include "ur_interface_loader.hpp" -#include #include #include #include @@ -26,7 +26,6 @@ class ur_legacy_sink; struct ur_adapter_handle_t_ : ur::handle_base { ur_adapter_handle_t_(); - std::atomic RefCount = 0; zes_pfnDriverGetDeviceByUuidExp_t getDeviceByUUIdFunctionPtr = nullptr; zes_pfnDriverGet_t getSysManDriversFunctionPtr = nullptr; @@ -45,6 +44,11 @@ struct ur_adapter_handle_t_ : ur::handle_base { ZeCache> PlatformCache; logger::Logger &logger; HMODULE processHandle = nullptr; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; extern ur_adapter_handle_t_ *GlobalAdapter; diff --git a/unified-runtime/source/adapters/level_zero/async_alloc.cpp b/unified-runtime/source/adapters/level_zero/async_alloc.cpp index 204b43c3bcc79..201b5b4d17285 100644 --- a/unified-runtime/source/adapters/level_zero/async_alloc.cpp +++ b/unified-runtime/source/adapters/level_zero/async_alloc.cpp @@ -247,7 +247,7 @@ ur_result_t urEnqueueUSMFreeExp( } size_t size = umfPoolMallocUsableSize(hPool, Mem); - (*Event)->RefCount.increment(); + (*Event)->getRefCount().retain(); usmPool->AsyncPool.insert(Mem, size, *Event, Queue); // Signal that USM free event was finished diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/command_buffer.cpp index a69f23f286dfd..87ef900d4927d 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.cpp @@ -842,13 +842,13 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device, ur_result_t urCommandBufferRetainExp(ur_exp_command_buffer_handle_t CommandBuffer) { - CommandBuffer->RefCount.increment(); + CommandBuffer->getRefCount().retain(); return UR_RESULT_SUCCESS; } ur_result_t urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t CommandBuffer) { - if (!CommandBuffer->RefCount.decrementAndTest()) + if (!CommandBuffer->getRefCount().release()) return UR_RESULT_SUCCESS; UR_CALL(waitForOngoingExecution(CommandBuffer)); @@ -1643,7 +1643,7 @@ ur_result_t enqueueImmediateAppendPath( if (CommandBuffer->CurrentSubmissionEvent) { UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent)); } - (*Event)->RefCount.increment(); + (*Event)->getRefCount().retain(); CommandBuffer->CurrentSubmissionEvent = *Event; UR_CALL(Queue->executeCommandList(CommandListHelper, false, false)); @@ -1726,7 +1726,7 @@ ur_result_t enqueueWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer, if (CommandBuffer->CurrentSubmissionEvent) { UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent)); } - (*Event)->RefCount.increment(); + (*Event)->getRefCount().retain(); CommandBuffer->CurrentSubmissionEvent = *Event; UR_CALL(Queue->executeCommandList(SignalCommandList, false /*IsBlocking*/, @@ -1850,7 +1850,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()}); + return ReturnValue(uint32_t{hCommandBuffer->getRefCount().getCount()}); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.hpp b/unified-runtime/source/adapters/level_zero/command_buffer.hpp index f7b62a9c8dd1e..bd9e8fd97310a 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.hpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.hpp @@ -17,6 +17,7 @@ #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "kernel.hpp" #include "queue.hpp" @@ -149,4 +150,9 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { // Track handle objects to free when command-buffer is destroyed. std::vector> CommandHandles; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/common.hpp b/unified-runtime/source/adapters/level_zero/common.hpp index 19e22de14605d..c33afe4144e75 100644 --- a/unified-runtime/source/adapters/level_zero/common.hpp +++ b/unified-runtime/source/adapters/level_zero/common.hpp @@ -34,6 +34,7 @@ #include #include +#include "common/ur_ref_count.hpp" #include "logger/ur_logger.hpp" #include "ur_interface_loader.hpp" @@ -220,55 +221,9 @@ void zeParseError(ze_result_t ZeError, const char *&ErrorString); #define ZE_CALL_NOCHECK_NAME(ZeName, ZeArgs, callName) \ ZeCall().doCall(ZeName ZeArgs, callName, #ZeArgs, false) -// This wrapper around std::atomic is created to limit operations with reference -// counter and to make allowed operations more transparent in terms of -// thread-safety in the plugin. increment() and load() operations do not need a -// mutex guard around them since the underlying data is already atomic. -// decrementAndTest() method is used to guard a code which needs to be -// executed when object's ref count becomes zero after release. This method also -// doesn't need a mutex guard because decrement operation is atomic and only one -// thread can reach ref count equal to zero, i.e. only a single thread can pass -// through this check. -struct ReferenceCounter { - ReferenceCounter() : RefCount{1} {} - - // Reset the counter to the initial value. - void reset() { RefCount = 1; } - - // Used when retaining an object. - void increment() { RefCount++; } - - // Supposed to be used in ur*GetInfo* methods where ref count value is - // requested. - uint32_t load() { return RefCount.load(); } - - // This method allows to guard a code which needs to be executed when object's - // ref count becomes zero after release. It is important to notice that only a - // single thread can pass through this check. This is true because of several - // reasons: - // 1. Decrement operation is executed atomically. - // 2. It is not allowed to retain an object after its refcount reaches zero. - // 3. It is not allowed to release an object more times than the value of - // the ref count. - // 2. and 3. basically means that we can't use an object at all as soon as its - // refcount reaches zero. Using this check guarantees that code for deleting - // an object and releasing its resources is executed once by a single thread - // and we don't need to use any mutexes to guard access to this object in the - // scope after this check. Of course if we access another objects in this code - // (not the one which is being deleted) then access to these objects must be - // guarded, for example with a mutex. - bool decrementAndTest() { return --RefCount == 0; } - -private: - std::atomic RefCount; -}; - // Base class to store common data struct ur_object : ur::handle_base { - ur_object() : handle_base(), RefCount{} {} - - // Must be atomic to prevent data race when incrementing/decrementing. - ReferenceCounter RefCount; + ur_object() : handle_base() {} // This mutex protects accesses to all the non-const member variables. // Exclusive access is required to modify any of these members. @@ -303,6 +258,11 @@ struct MemAllocRecord : ur_object { // TODO: this should go away when memory isolation issue is fixed in the Level // Zero runtime. ur_context_handle_t Context; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; extern usm::DisjointPoolAllConfigs DisjointPoolConfigInstance; diff --git a/unified-runtime/source/adapters/level_zero/context.cpp b/unified-runtime/source/adapters/level_zero/context.cpp index 3209b8b789155..275f39c478115 100644 --- a/unified-runtime/source/adapters/level_zero/context.cpp +++ b/unified-runtime/source/adapters/level_zero/context.cpp @@ -61,7 +61,7 @@ ur_result_t urContextRetain( /// [in] handle of the context to get a reference of. ur_context_handle_t Context) { - Context->RefCount.increment(); + Context->getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -113,7 +113,7 @@ ur_result_t urContextGetInfo( case UR_CONTEXT_INFO_NUM_DEVICES: return ReturnValue(uint32_t(Context->Devices.size())); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Context->RefCount.load()}); + return ReturnValue(uint32_t{Context->getRefCount().getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // 2D USM memcpy is supported. return ReturnValue(uint8_t{UseMemcpy2DOperations}); @@ -251,7 +251,7 @@ ur_device_handle_t ur_context_handle_t_::getRootDevice() const { // from the list of tracked contexts. ur_result_t ContextReleaseHelper(ur_context_handle_t Context) { - if (!Context->RefCount.decrementAndTest()) + if (!Context->getRefCount().release()) return UR_RESULT_SUCCESS; if (IndirectAccessTrackingEnabled) { diff --git a/unified-runtime/source/adapters/level_zero/context.hpp b/unified-runtime/source/adapters/level_zero/context.hpp index 86e0ea27b5c3e..b1bb958f56257 100644 --- a/unified-runtime/source/adapters/level_zero/context.hpp +++ b/unified-runtime/source/adapters/level_zero/context.hpp @@ -26,6 +26,7 @@ #include "queue.hpp" #include "usm.hpp" +#include "common/ur_ref_count.hpp" #include struct l0_command_list_cache_info { @@ -358,6 +359,8 @@ struct ur_context_handle_t_ : ur_object { // Get handle to the L0 context ze_context_handle_t getZeHandle() const; + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: enum EventFlags { EVENT_FLAG_HOST_VISIBLE = UR_BIT(0), @@ -404,6 +407,8 @@ struct ur_context_handle_t_ : ur_object { return &EventCaches[index]; } + + ur::RefCount RefCount; }; // Helper function to release the context, a caller must lock the platform-level diff --git a/unified-runtime/source/adapters/level_zero/device.cpp b/unified-runtime/source/adapters/level_zero/device.cpp index 6392b3802a199..2cd11cb3a957e 100644 --- a/unified-runtime/source/adapters/level_zero/device.cpp +++ b/unified-runtime/source/adapters/level_zero/device.cpp @@ -470,7 +470,7 @@ ur_result_t urDeviceGetInfo( return ReturnValue((uint32_t)Device->SubDevices.size()); } case UR_DEVICE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Device->RefCount.load()}); + return ReturnValue(uint32_t{Device->getRefCount().getCount()}); case UR_DEVICE_INFO_SUPPORTED_PARTITIONS: { // SYCL spec says: if this SYCL device cannot be partitioned into at least // two sub devices then the returned vector must be empty. @@ -1666,7 +1666,7 @@ ur_result_t urDeviceGetGlobalTimestamps( ur_result_t urDeviceRetain(ur_device_handle_t Device) { // The root-device ref-count remains unchanged (always 1). if (Device->isSubDevice()) { - Device->RefCount.increment(); + Device->getRefCount().retain(); } return UR_RESULT_SUCCESS; } @@ -1674,7 +1674,7 @@ ur_result_t urDeviceRetain(ur_device_handle_t Device) { ur_result_t urDeviceRelease(ur_device_handle_t Device) { // Root devices are destroyed during the piTearDown process. if (Device->isSubDevice()) { - if (Device->RefCount.decrementAndTest()) { + if (Device->getRefCount().release()) { delete Device; } } diff --git a/unified-runtime/source/adapters/level_zero/device.hpp b/unified-runtime/source/adapters/level_zero/device.hpp index a8326c0cf668b..34dd5fcad725d 100644 --- a/unified-runtime/source/adapters/level_zero/device.hpp +++ b/unified-runtime/source/adapters/level_zero/device.hpp @@ -20,6 +20,7 @@ #include "adapters/level_zero/platform.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include #include #include @@ -242,6 +243,11 @@ struct ur_device_handle_t_ : ur_object { // unique ephemeral identifer of the device in the adapter std::optional Id; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; inline std::vector diff --git a/unified-runtime/source/adapters/level_zero/event.cpp b/unified-runtime/source/adapters/level_zero/event.cpp index f06cae5ec0cb3..a31dc1ff72bf6 100644 --- a/unified-runtime/source/adapters/level_zero/event.cpp +++ b/unified-runtime/source/adapters/level_zero/event.cpp @@ -505,7 +505,7 @@ ur_result_t urEventGetInfo( return ReturnValue(Result); } case UR_EVENT_INFO_REFERENCE_COUNT: { - return ReturnValue(Event->RefCount.load()); + return ReturnValue(Event->getRefCount().getCount()); } default: UR_LOG(ERR, "Unsupported ParamName in urEventGetInfo: ParamName={}(0x{})", @@ -874,7 +874,7 @@ ur_result_t /// [in] handle of the event object urEventRetain(/** [in] handle of the event object */ ur_event_handle_t Event) { Event->RefCountExternal++; - Event->RefCount.increment(); + Event->getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -1088,7 +1088,7 @@ ur_event_handle_t_::~ur_event_handle_t_() { ur_result_t urEventReleaseInternal(ur_event_handle_t Event, bool *isEventDeleted) { - if (!Event->RefCount.decrementAndTest()) + if (!Event->getRefCount().release()) return UR_RESULT_SUCCESS; if (Event->OriginAllocEvent) { @@ -1524,7 +1524,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( std::shared_lock Lock(CurQueue->LastCommandEvent->Mutex); this->ZeEventList[0] = CurQueue->LastCommandEvent->ZeEvent; this->UrEventList[0] = CurQueue->LastCommandEvent; - this->UrEventList[0]->RefCount.increment(); + this->UrEventList[0]->getRefCount().retain(); TmpListLength = 1; } else if (EventListLength > 0) { this->ZeEventList = new ze_event_handle_t[EventListLength]; @@ -1660,7 +1660,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( IsInternal, IsMultiDevice)); MultiDeviceZeEvent = MultiDeviceEvent->ZeEvent; const auto &ZeCommandList = CommandList->first; - EventList[I]->RefCount.increment(); + EventList[I]->getRefCount().retain(); // Append a Barrier to wait on the original event while signalling the // new multi device event. @@ -1676,11 +1676,11 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( this->ZeEventList[TmpListLength] = MultiDeviceZeEvent; this->UrEventList[TmpListLength] = MultiDeviceEvent; - this->UrEventList[TmpListLength]->RefCount.increment(); + this->UrEventList[TmpListLength]->getRefCount().retain(); } else { this->ZeEventList[TmpListLength] = EventList[I]->ZeEvent; this->UrEventList[TmpListLength] = EventList[I]; - this->UrEventList[TmpListLength]->RefCount.increment(); + this->UrEventList[TmpListLength]->getRefCount().retain(); } if (QueueLock.has_value()) { diff --git a/unified-runtime/source/adapters/level_zero/event.hpp b/unified-runtime/source/adapters/level_zero/event.hpp index 13b36bcdfbe94..80565164902df 100644 --- a/unified-runtime/source/adapters/level_zero/event.hpp +++ b/unified-runtime/source/adapters/level_zero/event.hpp @@ -25,6 +25,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "queue.hpp" #include "ur_api.h" @@ -262,6 +263,11 @@ struct ur_event_handle_t_ : ur_object { // Used only for asynchronous allocations. This is the event originally used // on async free to indicate when the allocation can be used again. ur_event_handle_t OriginAllocEvent = nullptr; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; // Helper function to implement zeHostSynchronize. diff --git a/unified-runtime/source/adapters/level_zero/kernel.cpp b/unified-runtime/source/adapters/level_zero/kernel.cpp index 29c6b19e2bbfe..ab50dd6708a29 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/kernel.cpp @@ -787,7 +787,7 @@ ur_result_t urKernelGetInfo( } } case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Kernel->RefCount.load()}); + return ReturnValue(uint32_t{Kernel->getRefCount().getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: try { uint32_t Size; @@ -938,7 +938,7 @@ ur_result_t urKernelGetSubGroupInfo( ur_result_t urKernelRetain( /// [in] handle for the Kernel to retain ur_kernel_handle_t Kernel) { - Kernel->RefCount.increment(); + Kernel->getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -946,7 +946,7 @@ ur_result_t urKernelRetain( ur_result_t urKernelRelease( /// [in] handle for the Kernel to release ur_kernel_handle_t Kernel) { - if (!Kernel->RefCount.decrementAndTest()) + if (!Kernel->getRefCount().release()) return UR_RESULT_SUCCESS; auto KernelProgram = Kernel->Program; diff --git a/unified-runtime/source/adapters/level_zero/kernel.hpp b/unified-runtime/source/adapters/level_zero/kernel.hpp index 7f80348cda31f..b56ea0fa89cd9 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/kernel.hpp @@ -9,9 +9,11 @@ //===----------------------------------------------------------------------===// #pragma once +#include + #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "memory.hpp" -#include struct ur_kernel_handle_t_ : ur_object { ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program) @@ -106,6 +108,11 @@ struct ur_kernel_handle_t_ : ur_object { // Cache of the kernel properties. ZeCache> ZeKernelProperties; ZeCache ZeKernelName; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; ur_result_t getZeKernel(ze_device_handle_t hDevice, ur_kernel_handle_t hKernel, diff --git a/unified-runtime/source/adapters/level_zero/memory.cpp b/unified-runtime/source/adapters/level_zero/memory.cpp index 0f6bb37dde904..e1139ff0e2c80 100644 --- a/unified-runtime/source/adapters/level_zero/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/memory.cpp @@ -1052,7 +1052,7 @@ ur_result_t urEnqueueMemBufferMap( // Add the event to the command list. CommandList->second.append(reinterpret_cast(*Event)); - (*Event)->RefCount.increment(); + (*Event)->getRefCount().retain(); const auto &ZeCommandList = CommandList->first; const auto &WaitList = (*Event)->WaitList; @@ -1183,7 +1183,7 @@ ur_result_t urEnqueueMemUnmap( nullptr /*ForcedCmdQueue*/)); CommandList->second.append(reinterpret_cast(*Event)); - (*Event)->RefCount.increment(); + (*Event)->getRefCount().retain(); const auto &ZeCommandList = CommandList->first; @@ -1635,14 +1635,14 @@ ur_result_t urMemBufferCreate( ur_result_t urMemRetain( /// [in] handle of the memory object to get access ur_mem_handle_t Mem) { - Mem->RefCount.increment(); + Mem->getRefCount().retain(); return UR_RESULT_SUCCESS; } ur_result_t urMemRelease( /// [in] handle of the memory object to release ur_mem_handle_t Mem) { - if (!Mem->RefCount.decrementAndTest()) + if (!Mem->getRefCount().release()) return UR_RESULT_SUCCESS; if (Mem->isImage()) { @@ -1848,7 +1848,7 @@ ur_result_t urMemGetInfo( return ReturnValue(size_t{Buffer->Size}); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(Buffer->RefCount.load()); + return ReturnValue(Buffer->getRefCount().getCount()); } default: { return UR_RESULT_ERROR_INVALID_ENUMERATION; diff --git a/unified-runtime/source/adapters/level_zero/memory.hpp b/unified-runtime/source/adapters/level_zero/memory.hpp index 715b5b51870c1..7bf7234293f3d 100644 --- a/unified-runtime/source/adapters/level_zero/memory.hpp +++ b/unified-runtime/source/adapters/level_zero/memory.hpp @@ -19,6 +19,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "event.hpp" #include "program.hpp" @@ -90,6 +91,8 @@ struct ur_mem_handle_t_ : ur_object { // Method to get type of the derived object (image or buffer) bool isImage() const { return mem_type == mem_type_t::image; } + ur::RefCount &getRefCount() noexcept { return RefCount; } + protected: ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context) : UrContext{Context}, UrDevice{nullptr}, mem_type(type) {} @@ -101,6 +104,9 @@ struct ur_mem_handle_t_ : ur_object { // Since the destructor isn't virtual, callers must destruct it via ur_buffer // or ur_image ~ur_mem_handle_t_() {}; + +private: + ur::RefCount RefCount; }; struct ur_buffer final : ur_mem_handle_t_ { @@ -116,7 +122,7 @@ struct ur_buffer final : ur_mem_handle_t_ { : ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), Size(Size), SubBuffer{{Parent, Origin}} { // Retain the Parent Buffer due to the Creation of the SubBuffer. - Parent->RefCount.increment(); + Parent->getRefCount().retain(); } // Interop-buffer constructor diff --git a/unified-runtime/source/adapters/level_zero/physical_mem.cpp b/unified-runtime/source/adapters/level_zero/physical_mem.cpp index 5d4d0acce0eb3..00debb98fe03e 100644 --- a/unified-runtime/source/adapters/level_zero/physical_mem.cpp +++ b/unified-runtime/source/adapters/level_zero/physical_mem.cpp @@ -42,12 +42,12 @@ ur_result_t urPhysicalMemCreate( } ur_result_t urPhysicalMemRetain(ur_physical_mem_handle_t hPhysicalMem) { - hPhysicalMem->RefCount.increment(); + hPhysicalMem->getRefCount().retain(); return UR_RESULT_SUCCESS; } ur_result_t urPhysicalMemRelease(ur_physical_mem_handle_t hPhysicalMem) { - if (!hPhysicalMem->RefCount.decrementAndTest()) + if (!hPhysicalMem->getRefCount().release()) return UR_RESULT_SUCCESS; if (checkL0LoaderTeardown()) { @@ -68,7 +68,7 @@ ur_result_t urPhysicalMemGetInfo(ur_physical_mem_handle_t hPhysicalMem, switch (propName) { case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hPhysicalMem->RefCount.load()); + return ReturnValue(hPhysicalMem->getRefCount().getCount()); } default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; diff --git a/unified-runtime/source/adapters/level_zero/physical_mem.hpp b/unified-runtime/source/adapters/level_zero/physical_mem.hpp index 6ce630bcc5e1f..8af78bfb3da43 100644 --- a/unified-runtime/source/adapters/level_zero/physical_mem.hpp +++ b/unified-runtime/source/adapters/level_zero/physical_mem.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" struct ur_physical_mem_handle_t_ : ur_object { ur_physical_mem_handle_t_(ze_physical_mem_handle_t ZePhysicalMem, @@ -21,4 +22,9 @@ struct ur_physical_mem_handle_t_ : ur_object { // Keeps the PI context of this memory handle. ur_context_handle_t Context; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index 497e3057b7b9b..edcff4e628729 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -558,14 +558,14 @@ ur_result_t urProgramLinkExp( ur_result_t urProgramRetain( /// [in] handle for the Program to retain ur_program_handle_t Program) { - Program->RefCount.increment(); + Program->getRefCount().retain(); return UR_RESULT_SUCCESS; } ur_result_t urProgramRelease( /// [in] handle for the Program to release ur_program_handle_t Program) { - if (!Program->RefCount.decrementAndTest()) + if (!Program->getRefCount().release()) return UR_RESULT_SUCCESS; delete Program; @@ -708,7 +708,7 @@ ur_result_t urProgramGetInfo( switch (PropName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Program->RefCount.load()}); + return ReturnValue(uint32_t{Program->getRefCount().getCount()}); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(Program->Context); case UR_PROGRAM_INFO_NUM_DEVICES: @@ -1115,7 +1115,7 @@ void ur_program_handle_t_::ur_release_program_resources(bool deletion) { // must be destroyed before the Module can be destroyed. So, be sure // to destroy build log before destroying the module. if (!deletion) { - if (!RefCount.decrementAndTest()) { + if (!RefCount.release()) { return; } } diff --git a/unified-runtime/source/adapters/level_zero/program.hpp b/unified-runtime/source/adapters/level_zero/program.hpp index 789daf052ba0c..cd73cd43c7031 100644 --- a/unified-runtime/source/adapters/level_zero/program.hpp +++ b/unified-runtime/source/adapters/level_zero/program.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" struct ur_program_handle_t_ : ur_object { @@ -226,6 +227,8 @@ struct ur_program_handle_t_ : ur_object { // UR_PROGRAM_INFO_BINARY_SIZES. const std::vector AssociatedDevices; + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: struct DeviceData { // Log from the result of building the program for the device using @@ -264,4 +267,6 @@ struct ur_program_handle_t_ : ur_object { // handle from the program. // TODO: Currently interoparability UR API does not support multiple devices. ze_module_handle_t InteropZeModule = nullptr; + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/queue.cpp b/unified-runtime/source/adapters/level_zero/queue.cpp index dc7d19ffaa007..cb01619041921 100644 --- a/unified-runtime/source/adapters/level_zero/queue.cpp +++ b/unified-runtime/source/adapters/level_zero/queue.cpp @@ -369,7 +369,7 @@ ur_result_t urQueueGetInfo( case UR_QUEUE_INFO_DEVICE: return ReturnValue(Queue->Device); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Queue->RefCount.load()}); + return ReturnValue(uint32_t{Queue->getRefCount().getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(Queue->Properties); case UR_QUEUE_INFO_SIZE: @@ -593,7 +593,7 @@ ur_result_t urQueueRetain( std::scoped_lock Lock(Queue->Mutex); Queue->RefCountExternal++; } - Queue->RefCount.increment(); + Queue->getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -612,7 +612,7 @@ ur_result_t urQueueRelease( // internal reference count. When the External Reference count == 0, then // cleanup of the queue begins and the final decrement of the internal // reference count is completed. - static_cast(Queue->RefCount.decrementAndTest()); + static_cast(Queue->getRefCount().release()); return UR_RESULT_SUCCESS; } @@ -1389,7 +1389,7 @@ ur_queue_handle_t_::executeCommandList(ur_command_list_ptr_t CommandList, if (!Event->HostVisibleEvent) { Event->HostVisibleEvent = reinterpret_cast(HostVisibleEvent); - HostVisibleEvent->RefCount.increment(); + HostVisibleEvent->getRefCount().retain(); } } @@ -1550,7 +1550,7 @@ ur_result_t ur_queue_handle_t_::addEventToQueueCache(ur_event_handle_t Event) { } void ur_queue_handle_t_::active_barriers::add(ur_event_handle_t &Event) { - Event->RefCount.increment(); + Event->getRefCount().retain(); Events.push_back(Event); } @@ -1588,7 +1588,7 @@ void ur_queue_handle_t_::clearEndTimeRecordings() { } ur_result_t urQueueReleaseInternal(ur_queue_handle_t Queue) { - if (!Queue->RefCount.decrementAndTest()) + if (!Queue->getRefCount().release()) return UR_RESULT_SUCCESS; for (auto &Cache : Queue->EventCaches) { @@ -1921,7 +1921,7 @@ ur_result_t createEventAndAssociateQueue(ur_queue_handle_t Queue, // Append this Event to the CommandList, if any if (CommandList != Queue->CommandListMap.end()) { CommandList->second.append(*Event); - (*Event)->RefCount.increment(); + (*Event)->getRefCount().retain(); } // We need to increment the reference counter here to avoid ur_queue_handle_t @@ -1929,7 +1929,7 @@ ur_result_t createEventAndAssociateQueue(ur_queue_handle_t Queue, // urEventRelease requires access to the associated ur_queue_handle_t. // In urEventRelease, the reference counter of the Queue is decremented // to release it. - Queue->RefCount.increment(); + Queue->getRefCount().retain(); // SYCL RT does not track completion of the events, so it could // release a PI event as soon as that's not being waited in the app. @@ -1961,7 +1961,7 @@ void ur_queue_handle_t_::CaptureIndirectAccesses() { // SubmissionsCount turns to 0. We don't want to know how many times // allocation was retained by each submission. if (Pair.second) - Elem.second.RefCount.increment(); + Elem.second.getRefCount().retain(); } } Kernel->SubmissionsCount++; diff --git a/unified-runtime/source/adapters/level_zero/queue.hpp b/unified-runtime/source/adapters/level_zero/queue.hpp index 405929c8f0f0e..daa33a76f47b0 100644 --- a/unified-runtime/source/adapters/level_zero/queue.hpp +++ b/unified-runtime/source/adapters/level_zero/queue.hpp @@ -25,6 +25,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" extern "C" { @@ -692,6 +693,11 @@ struct ur_queue_handle_t_ : ur_object { // Pointer to the unified handle. ur_queue_handle_t_ *UnifiedHandle; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; // This helper function creates a ur_event_handle_t and associate a diff --git a/unified-runtime/source/adapters/level_zero/sampler.cpp b/unified-runtime/source/adapters/level_zero/sampler.cpp index 4f6f5760faada..4d9d80b77c6e3 100644 --- a/unified-runtime/source/adapters/level_zero/sampler.cpp +++ b/unified-runtime/source/adapters/level_zero/sampler.cpp @@ -124,14 +124,14 @@ ur_result_t urSamplerCreate( ur_result_t urSamplerRetain( /// [in] handle of the sampler object to get access ur_sampler_handle_t Sampler) { - Sampler->RefCount.increment(); + Sampler->getRefCount().retain(); return UR_RESULT_SUCCESS; } ur_result_t urSamplerRelease( /// [in] handle of the sampler object to release ur_sampler_handle_t Sampler) { - if (!Sampler->RefCount.decrementAndTest()) + if (!Sampler->getRefCount().release()) return UR_RESULT_SUCCESS; if (checkL0LoaderTeardown()) { diff --git a/unified-runtime/source/adapters/level_zero/sampler.hpp b/unified-runtime/source/adapters/level_zero/sampler.hpp index 9a834a05215d9..48584db343dda 100644 --- a/unified-runtime/source/adapters/level_zero/sampler.hpp +++ b/unified-runtime/source/adapters/level_zero/sampler.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" struct ur_sampler_handle_t_ : ur_object { ur_sampler_handle_t_(ze_sampler_handle_t Sampler) : ZeSampler{Sampler} {} @@ -18,6 +19,11 @@ struct ur_sampler_handle_t_ : ur_object { ze_sampler_handle_t ZeSampler; ZeStruct ZeSamplerDesc; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; // Construct ZE sampler desc from UR sampler desc. diff --git a/unified-runtime/source/adapters/level_zero/usm.cpp b/unified-runtime/source/adapters/level_zero/usm.cpp index ab8556dd692c7..1b7745924d476 100644 --- a/unified-runtime/source/adapters/level_zero/usm.cpp +++ b/unified-runtime/source/adapters/level_zero/usm.cpp @@ -523,14 +523,14 @@ ur_result_t urUSMPoolCreate( ur_result_t /// [in] pointer to USM memory pool urUSMPoolRetain(ur_usm_pool_handle_t Pool) { - Pool->RefCount.increment(); + Pool->getRefCount().retain(); return UR_RESULT_SUCCESS; } ur_result_t /// [in] pointer to USM memory pool urUSMPoolRelease(ur_usm_pool_handle_t Pool) { - if (Pool->RefCount.decrementAndTest()) { + if (Pool->getRefCount().release()) { std::scoped_lock ContextLock(Pool->Context->Mutex); Pool->Context->UsmPoolHandles.remove(Pool); delete Pool; @@ -553,7 +553,7 @@ ur_result_t urUSMPoolGetInfo( switch (PropName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(Pool->RefCount.load()); + return ReturnValue(Pool->getRefCount().getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(Pool->Context); @@ -1250,7 +1250,7 @@ ur_result_t ZeMemFreeHelper(ur_context_handle_t Context, void *Ptr) { if (It == std::end(Context->MemAllocs)) { die("All memory allocations must be tracked!"); } - if (!It->second.RefCount.decrementAndTest()) { + if (!It->second.getRefCount().release()) { // Memory can't be deallocated yet. return UR_RESULT_SUCCESS; } @@ -1297,7 +1297,7 @@ ur_result_t USMFreeHelper(ur_context_handle_t Context, void *Ptr, if (It == std::end(Context->MemAllocs)) { die("All memory allocations must be tracked!"); } - if (!It->second.RefCount.decrementAndTest()) { + if (!It->second.getRefCount().release()) { // Memory can't be deallocated yet. return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/level_zero/usm.hpp b/unified-runtime/source/adapters/level_zero/usm.hpp index b29ea29a7914d..489886b8a3bce 100644 --- a/unified-runtime/source/adapters/level_zero/usm.hpp +++ b/unified-runtime/source/adapters/level_zero/usm.hpp @@ -9,13 +9,14 @@ //===----------------------------------------------------------------------===// #pragma once -#include "common.hpp" +#include +#include "common.hpp" +#include "common/ur_ref_count.hpp" #include "enqueued_pool.hpp" #include "event.hpp" #include "ur_api.h" #include "ur_pool_manager.hpp" -#include #include usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); @@ -53,9 +54,13 @@ struct ur_usm_pool_handle_t_ : ur_object { ur_context_handle_t Context; + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: UsmPool *getPool(const usm::pool_descriptor &Desc); usm::pool_manager PoolManager; + + ur::RefCount RefCount; }; // Exception type to pass allocation errors diff --git a/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp index cc2a88fb409e5..2a7fe2a3e7be9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp @@ -258,7 +258,7 @@ urCommandBufferCreateExp(ur_context_handle_t context, ur_device_handle_t device, ur_result_t urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { - hCommandBuffer->RefCount.increment(); + hCommandBuffer->getRefCount().retain(); return UR_RESULT_SUCCESS; } catch (...) { return exceptionToResult(std::current_exception()); @@ -266,7 +266,7 @@ urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { ur_result_t urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { - if (!hCommandBuffer->RefCount.decrementAndTest()) + if (!hCommandBuffer->getRefCount().release()) return UR_RESULT_SUCCESS; if (auto executionEvent = hCommandBuffer->getExecutionEventUnlocked()) { @@ -630,7 +630,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()}); + return ReturnValue(uint32_t{hCommandBuffer->getRefCount().getCount()}); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp b/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp index 109ec09ce24fd..2ff945048ff74 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp @@ -9,15 +9,18 @@ //===----------------------------------------------------------------------===// #pragma once +#include + #include "../helpers/mutable_helpers.hpp" #include "command_list_manager.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "kernel.hpp" #include "lockable.hpp" #include "queue_api.hpp" -#include #include + struct kernel_command_handle; struct ur_exp_command_buffer_handle_t_ : public ur_object { @@ -62,6 +65,8 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { ur_event_handle_t createEventIfRequested(ur_exp_command_buffer_sync_point_t *retSyncPoint); + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: // Stores all sync points that are created by the command buffer. std::vector syncPoints; @@ -85,4 +90,6 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { ur_event_handle_t currentExecution = nullptr; v2::raii::cache_borrowed_event_pool eventPool; + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp index d29496ec39425..78da24390735e 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp @@ -916,7 +916,7 @@ ur_result_t ur_command_list_manager::appendNativeCommandExp( void ur_command_list_manager::recordSubmittedKernel( ur_kernel_handle_t hKernel) { submittedKernels.push_back(hKernel); - hKernel->RefCount.increment(); + hKernel->getRefCount().retain(); } ze_command_list_handle_t ur_command_list_manager::getZeCommandList() { diff --git a/unified-runtime/source/adapters/level_zero/v2/context.cpp b/unified-runtime/source/adapters/level_zero/v2/context.cpp index 9ad4a4e61d633..3c64de388ea4d 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.cpp @@ -80,12 +80,12 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext, defaultUSMPool(this, nullptr), asyncPool(this, nullptr) {} ur_result_t ur_context_handle_t_::retain() { - RefCount.increment(); + RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t ur_context_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (!RefCount.release()) return UR_RESULT_SUCCESS; delete this; @@ -201,7 +201,7 @@ ur_result_t urContextGetInfo(ur_context_handle_t hContext, case UR_CONTEXT_INFO_NUM_DEVICES: return ReturnValue(uint32_t(hContext->getDevices().size())); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hContext->RefCount.load()}); + return ReturnValue(uint32_t{hContext->getRefCount().getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // TODO: this is currently not implemented return ReturnValue(uint8_t{false}); diff --git a/unified-runtime/source/adapters/level_zero/v2/context.hpp b/unified-runtime/source/adapters/level_zero/v2/context.hpp index 8d3cf8ca05579..a060b7070f8be 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.hpp @@ -14,6 +14,7 @@ #include "command_list_cache.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event_pool_cache.hpp" #include "usm.hpp" @@ -64,6 +65,8 @@ struct ur_context_handle_t_ : ur_object { // For that the Device or its root devices need to be in the context. bool isValidDevice(ur_device_handle_t Device) const; + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: const v2::raii::ze_context_handle_t hContext; const std::vector hDevices; @@ -81,4 +84,6 @@ struct ur_context_handle_t_ : ur_object { ur_usm_pool_handle_t_ defaultUSMPool; ur_usm_pool_handle_t_ asyncPool; std::list usmPoolHandles; + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/event.cpp b/unified-runtime/source/adapters/level_zero/v2/event.cpp index 30816a9fcdd6f..3385d2f00b59c 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event.cpp @@ -160,12 +160,12 @@ ze_event_handle_t ur_event_handle_t_::getZeEvent() const { } ur_result_t ur_event_handle_t_::retain() { - RefCount.increment(); + RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t ur_event_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (!RefCount.release()) return UR_RESULT_SUCCESS; if (event_pool) { @@ -258,7 +258,7 @@ ur_result_t urEventGetInfo(ur_event_handle_t hEvent, ur_event_info_t propName, } } case UR_EVENT_INFO_REFERENCE_COUNT: { - return returnValue(hEvent->RefCount.load()); + return returnValue(hEvent->getRefCount().getCount()); } case UR_EVENT_INFO_COMMAND_QUEUE: { auto urQueueHandle = reinterpret_cast(hEvent->getQueue()) - diff --git a/unified-runtime/source/adapters/level_zero/v2/event.hpp b/unified-runtime/source/adapters/level_zero/v2/event.hpp index 0e9386578a2f6..9320fa34d4e36 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/event.hpp @@ -17,6 +17,7 @@ #include "adapters/level_zero/v2/queue_api.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event_provider.hpp" namespace v2 { @@ -112,10 +113,14 @@ struct ur_event_handle_t_ : ur_object { uint64_t getEventStartTimestmap() const; uint64_t getEventEndTimestamp(); + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: ur_event_handle_t_(ur_context_handle_t hContext, event_variant hZeEvent, v2::event_flags_t flags, v2::event_pool *pool); + ur::RefCount RefCount; + protected: ur_context_handle_t hContext; diff --git a/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp b/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp index 99d4852f9ad4a..30c3c98809525 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp @@ -52,8 +52,8 @@ void event_pool::free(ur_event_handle_t event) { freelist.push_back(event); // The event is still in the pool, so we need to increment the refcount - assert(event->RefCount.load() == 0); - event->RefCount.increment(); + assert(event->getRefCount().getCount() == 0); + event->getRefCount().retain(); } event_provider *event_pool::getProvider() const { return provider.get(); } diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp index a2189b57536e8..9c4a9ae86a2cb 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp @@ -97,7 +97,7 @@ ur_kernel_handle_t_::ur_kernel_handle_t_( } ur_result_t ur_kernel_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (!RefCount.release()) return UR_RESULT_SUCCESS; // manually release kernels to allow errors to be propagated @@ -370,7 +370,7 @@ urKernelCreateWithNativeHandle(ur_native_handle_t hNativeKernel, ur_result_t urKernelRetain( /// [in] handle for the Kernel to retain ur_kernel_handle_t hKernel) try { - hKernel->RefCount.increment(); + hKernel->getRefCount().retain(); return UR_RESULT_SUCCESS; } catch (...) { return exceptionToResult(std::current_exception()); @@ -634,7 +634,7 @@ ur_result_t urKernelGetInfo(ur_kernel_handle_t hKernel, spills.size()); } case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hKernel->RefCount.load()}); + return ReturnValue(uint32_t{hKernel->getRefCount().getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: { auto attributes = hKernel->getSourceAttributes(); return ReturnValue(static_cast(attributes.data())); diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp index 0cabb888ac3be..bafda2d1c963b 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp @@ -13,6 +13,7 @@ #include "../program.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "memory.hpp" struct ur_single_device_kernel_t { @@ -91,6 +92,8 @@ struct ur_kernel_handle_t_ : ur_object { ze_command_list_handle_t cmdList, wait_list_view &waitListView); + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: // Keep the program of the kernel. const ur_program_handle_t hProgram; @@ -116,4 +119,6 @@ struct ur_kernel_handle_t_ : ur_object { // pointer to any non-null kernel in deviceKernels ur_single_device_kernel_t *nonEmptyKernel; + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.cpp b/unified-runtime/source/adapters/level_zero/v2/memory.cpp index b1f3829dd6967..f235939ab82a6 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.cpp @@ -671,7 +671,7 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMem, ur_mem_info_t propName, return returnValue(size_t{hMem->getBuffer()->getSize()}); } case UR_MEM_INFO_REFERENCE_COUNT: { - return returnValue(hMem->getObject()->RefCount.load()); + return returnValue(hMem->getRefCount().getCount()); } default: { return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -684,14 +684,14 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMem, ur_mem_info_t propName, } ur_result_t urMemRetain(ur_mem_handle_t hMem) try { - hMem->getObject()->RefCount.increment(); + hMem->getRefCount().retain(); return UR_RESULT_SUCCESS; } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urMemRelease(ur_mem_handle_t hMem) try { - if (!hMem->getObject()->RefCount.decrementAndTest()) + if (!hMem->getRefCount().release()) return UR_RESULT_SUCCESS; delete hMem; diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.hpp b/unified-runtime/source/adapters/level_zero/v2/memory.hpp index 9c0dc66ef72b4..5ae8f810ed187 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.hpp @@ -19,6 +19,7 @@ #include "../image_common.hpp" #include "command_list_manager.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" using usm_unique_ptr_t = std::unique_ptr>; @@ -279,16 +280,10 @@ struct ur_mem_handle_t_ : ur::handle_base { mem); } - ur_object *getObject() { - return std::visit( - [](auto &&arg) -> ur_object * { - return static_cast(&arg); - }, - mem); - } - bool isImage() const { return std::holds_alternative(mem); } + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: template ur_mem_handle_t_(std::in_place_type_t, Args &&...args) @@ -299,4 +294,7 @@ struct ur_mem_handle_t_ : ur::handle_base { ur_discrete_buffer_handle_t, ur_shared_buffer_handle_t, ur_mem_sub_buffer_t, ur_mem_image_t> mem; + +private: + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp index 9831afdbc9e4c..cd5a8cc1fee86 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp @@ -13,11 +13,12 @@ #pragma once +#include + #include "../common.hpp" #include "queue_immediate_in_order.hpp" #include "queue_immediate_out_of_order.hpp" #include -#include struct ur_queue_handle_t_ : ur::handle_base { using data_variant = std::variant { ur_result_t queueRetain() { return std::visit( [](auto &q) { - q.RefCount.increment(); + q.getRefCount().retain(); return UR_RESULT_SUCCESS; }, queue_data); @@ -59,7 +60,7 @@ struct ur_queue_handle_t_ : ur::handle_base { ur_result_t queueRelease() { return std::visit( [queueHandle = this](auto &q) { - if (!q.RefCount.decrementAndTest()) + if (!q.getRefCount().release()) return UR_RESULT_SUCCESS; delete queueHandle; return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp index cc9b464333e70..85e6c7e2503c9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp @@ -60,7 +60,7 @@ ur_queue_immediate_in_order_t::queueGetInfo(ur_queue_info_t propName, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hDevice); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{RefCount.load()}); + return ReturnValue(uint32_t{RefCount.getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(flags); case UR_QUEUE_INFO_SIZE: diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp index 362a6ea31c9f4..7923e04923328 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp @@ -12,6 +12,7 @@ #include "../common.hpp" #include "../device.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "event.hpp" #include "event_pool_cache.hpp" @@ -32,6 +33,7 @@ struct ur_queue_immediate_in_order_t : ur_object, ur_queue_t_ { lockable commandListManager; ur_queue_flags_t flags; v2::raii::cache_borrowed_event_pool eventPool; + ur::RefCount RefCount; public: ur_queue_immediate_in_order_t(ur_context_handle_t, ur_device_handle_t, @@ -451,6 +453,8 @@ struct ur_queue_immediate_in_order_t : ur_object, ur_queue_t_ { numEventsInWaitList, phEventWaitList, createEventIfRequested(eventPool.get(), phEvent, this)); } + + ur::RefCount &getRefCount() noexcept { return RefCount; } }; } // namespace v2 diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp index bfb6079af3ea5..f0344eba267f9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp @@ -54,7 +54,7 @@ ur_result_t ur_queue_immediate_out_of_order_t::queueGetInfo( case UR_QUEUE_INFO_DEVICE: return ReturnValue(hDevice); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{RefCount.load()}); + return ReturnValue(uint32_t{RefCount.getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(flags); case UR_QUEUE_INFO_SIZE: diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp index 1d0bf5636d58c..c34b208722630 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp @@ -11,6 +11,7 @@ #include "../common.hpp" #include "../device.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "event.hpp" @@ -49,6 +50,8 @@ struct ur_queue_immediate_out_of_order_t : ur_object, ur_queue_t_ { numCommandLists; } + ur::RefCount RefCount; + public: ur_queue_immediate_out_of_order_t(ur_context_handle_t, ur_device_handle_t, uint32_t ordinal, @@ -503,6 +506,8 @@ struct ur_queue_immediate_out_of_order_t : ur_object, ur_queue_t_ { numEventsInWaitList, phEventWaitList, createEventIfRequested(eventPool.get(), phEvent, this)); } + + ur::RefCount &getRefCount() noexcept { return RefCount; } }; } // namespace v2 diff --git a/unified-runtime/source/adapters/level_zero/v2/usm.cpp b/unified-runtime/source/adapters/level_zero/v2/usm.cpp index f455fd3763554..85cd533b39704 100644 --- a/unified-runtime/source/adapters/level_zero/v2/usm.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/usm.cpp @@ -332,7 +332,7 @@ ur_result_t urUSMPoolCreate( ur_result_t /// [in] pointer to USM memory pool urUSMPoolRetain(ur_usm_pool_handle_t hPool) try { - hPool->RefCount.increment(); + hPool->getRefCount().retain(); return UR_RESULT_SUCCESS; } catch (umf_result_t e) { return umf::umf2urResult(e); @@ -343,7 +343,7 @@ urUSMPoolRetain(ur_usm_pool_handle_t hPool) try { ur_result_t /// [in] pointer to USM memory pool urUSMPoolRelease(ur_usm_pool_handle_t hPool) try { - if (hPool->RefCount.decrementAndTest()) { + if (hPool->getRefCount().release()) { hPool->getContextHandle()->removeUsmPool(hPool); delete hPool; } @@ -369,7 +369,7 @@ ur_result_t urUSMPoolGetInfo( switch (propName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(hPool->RefCount.load()); + return ReturnValue(hPool->getRefCount().getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(hPool->getContextHandle()); diff --git a/unified-runtime/source/adapters/level_zero/v2/usm.hpp b/unified-runtime/source/adapters/level_zero/v2/usm.hpp index ff33b5f6bbed1..4b38d71a0bf29 100644 --- a/unified-runtime/source/adapters/level_zero/v2/usm.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/usm.hpp @@ -14,6 +14,7 @@ #include "../enqueued_pool.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event.hpp" #include "ur_pool_manager.hpp" @@ -49,9 +50,13 @@ struct ur_usm_pool_handle_t_ : ur_object { void cleanupPools(); void cleanupPoolsForQueue(void *hQueue); + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: ur_context_handle_t hContext; usm::pool_manager poolManager; UsmPool *getPool(const usm::pool_descriptor &desc); + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/native_cpu/adapter.cpp b/unified-runtime/source/adapters/native_cpu/adapter.cpp index 3fd6d4256825b..91b8c9976be61 100644 --- a/unified-runtime/source/adapters/native_cpu/adapter.cpp +++ b/unified-runtime/source/adapters/native_cpu/adapter.cpp @@ -10,17 +10,22 @@ #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "ur_api.h" struct ur_adapter_handle_t_ : ur::native_cpu::handle_base { - std::atomic RefCount = 0; logger::Logger &logger = logger::get_logger("native_cpu"); + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; } Adapter; UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (phAdapters) { - Adapter.RefCount++; + Adapter.getRefCount().retain(); *phAdapters = &Adapter; } if (pNumAdapters) { @@ -30,12 +35,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - Adapter.RefCount--; + Adapter.getRefCount().release(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - Adapter.RefCount++; + Adapter.getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -57,7 +62,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_NATIVE_CPU); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(Adapter.RefCount.load()); + return ReturnValue(Adapter.getRefCount().getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/native_cpu/common.hpp b/unified-runtime/source/adapters/native_cpu/common.hpp index e768a4b2ac7f4..7ca8964d514a8 100644 --- a/unified-runtime/source/adapters/native_cpu/common.hpp +++ b/unified-runtime/source/adapters/native_cpu/common.hpp @@ -44,22 +44,13 @@ struct ddi_getter { using handle_base = ur::handle_base; } // namespace ur::native_cpu -// Todo: replace this with a common helper once it is available -struct RefCounted : ur::native_cpu::handle_base { - std::atomic_uint32_t _refCount; - uint32_t incrementReferenceCount() { return ++_refCount; } - uint32_t decrementReferenceCount() { return --_refCount; } - RefCounted() : handle_base(), _refCount{1} {} - uint32_t getReferenceCount() const { return _refCount; } -}; - // Base class to store common data -struct ur_object : RefCounted { +struct ur_object { ur_shared_mutex Mutex; }; template inline void decrementOrDelete(T *refC) { - if (refC->decrementReferenceCount() == 0) + if (refC->getRefCount().release() == 0) delete refC; } diff --git a/unified-runtime/source/adapters/native_cpu/context.cpp b/unified-runtime/source/adapters/native_cpu/context.cpp index 5b7e8fc839884..38cb4efe0c45e 100644 --- a/unified-runtime/source/adapters/native_cpu/context.cpp +++ b/unified-runtime/source/adapters/native_cpu/context.cpp @@ -30,7 +30,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate( UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - hContext->incrementReferenceCount(); + hContext->getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -51,7 +51,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, case UR_CONTEXT_INFO_DEVICES: return returnValue(hContext->_device); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return returnValue(uint32_t{hContext->getReferenceCount()}); + return returnValue(uint32_t{hContext->getRefCount().getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: return returnValue(true); case UR_CONTEXT_INFO_USM_FILL2D_SUPPORT: diff --git a/unified-runtime/source/adapters/native_cpu/context.hpp b/unified-runtime/source/adapters/native_cpu/context.hpp index b9d2d22dd1565..7acdffefc8dc8 100644 --- a/unified-runtime/source/adapters/native_cpu/context.hpp +++ b/unified-runtime/source/adapters/native_cpu/context.hpp @@ -15,6 +15,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" #include "ur/ur.hpp" @@ -83,7 +84,7 @@ static usm_alloc_info get_alloc_info(void *ptr) { } // namespace native_cpu -struct ur_context_handle_t_ : RefCounted { +struct ur_context_handle_t_ { ur_context_handle_t_(ur_device_handle_t_ *phDevices) : _device{phDevices} {} ur_device_handle_t _device; @@ -135,7 +136,10 @@ struct ur_context_handle_t_ : RefCounted { return ptr; } + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: std::mutex alloc_mutex; std::set allocations; + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/native_cpu/event.cpp b/unified-runtime/source/adapters/native_cpu/event.cpp index 91b8fb302eb18..836404d4e30ff 100644 --- a/unified-runtime/source/adapters/native_cpu/event.cpp +++ b/unified-runtime/source/adapters/native_cpu/event.cpp @@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, case UR_EVENT_INFO_COMMAND_TYPE: return ReturnValue(hEvent->getCommandType()); case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->getRefCount().getCount()); case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: return ReturnValue(hEvent->getExecutionStatus()); case UR_EVENT_INFO_CONTEXT: @@ -69,7 +69,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - hEvent->incrementReferenceCount(); + hEvent->getRefCount().retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/event.hpp b/unified-runtime/source/adapters/native_cpu/event.hpp index 479c671b38cd1..d5587ec5543a4 100644 --- a/unified-runtime/source/adapters/native_cpu/event.hpp +++ b/unified-runtime/source/adapters/native_cpu/event.hpp @@ -8,14 +8,17 @@ // //===----------------------------------------------------------------------===// #pragma once -#include "common.hpp" -#include "ur_api.h" + #include #include #include #include -struct ur_event_handle_t_ : RefCounted { +#include "common.hpp" +#include "common/ur_ref_count.hpp" +#include "ur_api.h" + +struct ur_event_handle_t_ { ur_event_handle_t_(ur_queue_handle_t queue, ur_command_t command_type); @@ -55,6 +58,8 @@ struct ur_event_handle_t_ : RefCounted { uint64_t get_end_timestamp() const { return timestamp_end; } + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: ur_queue_handle_t queue; ur_context_handle_t context; @@ -65,4 +70,5 @@ struct ur_event_handle_t_ : RefCounted { std::packaged_task callback; uint64_t timestamp_start = 0; uint64_t timestamp_end = 0; + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/native_cpu/kernel.cpp b/unified-runtime/source/adapters/native_cpu/kernel.cpp index ac11331357f39..d8742f52cfb67 100644 --- a/unified-runtime/source/adapters/native_cpu/kernel.cpp +++ b/unified-runtime/source/adapters/native_cpu/kernel.cpp @@ -95,7 +95,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_FUNCTION_NAME: return ReturnValue(hKernel->_name.c_str()); case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hKernel->getReferenceCount()}); + return ReturnValue(uint32_t{hKernel->getRefCount().getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: return ReturnValue(""); case UR_KERNEL_INFO_SPILL_MEM_SIZE: @@ -194,7 +194,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetSubGroupInfo( } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - hKernel->incrementReferenceCount(); + hKernel->getRefCount().retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/kernel.hpp b/unified-runtime/source/adapters/native_cpu/kernel.hpp index 8daf23feb65f5..32d4f7edd46a3 100644 --- a/unified-runtime/source/adapters/native_cpu/kernel.hpp +++ b/unified-runtime/source/adapters/native_cpu/kernel.hpp @@ -8,13 +8,15 @@ #pragma once +#include +#include + #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "memory.hpp" #include "nativecpu_state.hpp" #include "program.hpp" -#include #include -#include using nativecpu_kernel_t = void(void *const *, native_cpu::state *); using nativecpu_ptr_t = nativecpu_kernel_t *; @@ -27,7 +29,7 @@ struct local_arg_info_t { : argIndex(argIndex), argSize(argSize) {} }; -struct ur_kernel_handle_t_ : RefCounted { +struct ur_kernel_handle_t_ { ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name, nativecpu_task_t subhandler) @@ -188,10 +190,12 @@ struct ur_kernel_handle_t_ : RefCounted { void addPtrArg(void *Ptr, size_t Index) { Args.addPtrArg(Index, Ptr); } void addArgReference(ur_mem_handle_t Arg) { - Arg->incrementReferenceCount(); + Arg->getRefCount().getCount(); ReferencedArgs.push_back(Arg); } + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: void removeArgReferences() { for (auto arg : ReferencedArgs) @@ -209,4 +213,5 @@ struct ur_kernel_handle_t_ : RefCounted { std::optional MaxWGSize = std::nullopt; std::optional MaxLinearWGSize = std::nullopt; std::vector ReferencedArgs; + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/native_cpu/memory.hpp b/unified-runtime/source/adapters/native_cpu/memory.hpp index ca6e3e77f5e87..2e3c2d5531ea0 100644 --- a/unified-runtime/source/adapters/native_cpu/memory.hpp +++ b/unified-runtime/source/adapters/native_cpu/memory.hpp @@ -15,6 +15,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" struct ur_mem_handle_t_ : ur_object { @@ -43,8 +44,11 @@ struct ur_mem_handle_t_ : ur_object { char *_mem; bool _ownsMem; + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: const bool IsImage; + ur::RefCount RefCount; }; struct ur_buffer final : ur_mem_handle_t_ { diff --git a/unified-runtime/source/adapters/native_cpu/program.cpp b/unified-runtime/source/adapters/native_cpu/program.cpp index fee72f8a6bc3c..666d74a5e7132 100644 --- a/unified-runtime/source/adapters/native_cpu/program.cpp +++ b/unified-runtime/source/adapters/native_cpu/program.cpp @@ -171,7 +171,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - hProgram->incrementReferenceCount(); + hProgram->getRefCount().retain(); return UR_RESULT_SUCCESS; } @@ -205,7 +205,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return returnValue(hProgram->getReferenceCount()); + return returnValue(hProgram->getRefCount().getCount()); case UR_PROGRAM_INFO_CONTEXT: return returnValue(nullptr); case UR_PROGRAM_INFO_NUM_DEVICES: diff --git a/unified-runtime/source/adapters/native_cpu/program.hpp b/unified-runtime/source/adapters/native_cpu/program.hpp index d58412751e8f2..9e51f68b362c4 100644 --- a/unified-runtime/source/adapters/native_cpu/program.hpp +++ b/unified-runtime/source/adapters/native_cpu/program.hpp @@ -10,23 +10,22 @@ #pragma once +#include +#include + #include +#include "common/ur_ref_count.hpp" #include "context.hpp" -#include -#include - namespace native_cpu { using WGSize_t = std::array; } -struct ur_program_handle_t_ : RefCounted { +struct ur_program_handle_t_ { ur_program_handle_t_(ur_context_handle_t ctx, const unsigned char *pBinary) : _ctx{ctx}, _ptr{pBinary} {} - uint32_t getReferenceCount() const noexcept { return _refCount; } - ur_context_handle_t _ctx; const unsigned char *_ptr; struct _compare { @@ -41,6 +40,11 @@ struct ur_program_handle_t_ : RefCounted { std::unordered_map KernelMaxWorkGroupSizeMD; std::unordered_map KernelMaxLinearWorkGroupSizeMD; + + ur::RefCount &getRefCount() noexcept { return RefCount; } + +private: + ur::RefCount RefCount; }; // The nativecpu_entry struct is also defined as LLVM-IR in the diff --git a/unified-runtime/source/adapters/native_cpu/queue.cpp b/unified-runtime/source/adapters/native_cpu/queue.cpp index 5de7037519490..0bfaf83ee8afe 100644 --- a/unified-runtime/source/adapters/native_cpu/queue.cpp +++ b/unified-runtime/source/adapters/native_cpu/queue.cpp @@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hQueue->getDevice()); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->getRefCount().getCount()); case UR_QUEUE_INFO_EMPTY: return ReturnValue(hQueue->isEmpty()); default: @@ -48,7 +48,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate( } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - hQueue->incrementReferenceCount(); + hQueue->getRefCount().retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/queue.hpp b/unified-runtime/source/adapters/native_cpu/queue.hpp index 9fd28c3e7ff00..da831708fbcab 100644 --- a/unified-runtime/source/adapters/native_cpu/queue.hpp +++ b/unified-runtime/source/adapters/native_cpu/queue.hpp @@ -8,12 +8,15 @@ // //===----------------------------------------------------------------------===// #pragma once + +#include + #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event.hpp" #include "ur_api.h" -#include -struct ur_queue_handle_t_ : RefCounted { +struct ur_queue_handle_t_ { ur_queue_handle_t_(ur_device_handle_t device, ur_context_handle_t context, const ur_queue_properties_t *pProps) : device(device), context(context), @@ -43,7 +46,7 @@ struct ur_queue_handle_t_ : RefCounted { auto ev = *events.begin(); // ur_event_handle_t_::wait removes itself from the events set in the // queue. - ev->incrementReferenceCount(); + ev->getRefCount().retain(); // Unlocking mutex for removeEvent and for event callbacks that may need // to acquire it. lock.unlock(); @@ -64,6 +67,8 @@ struct ur_queue_handle_t_ : RefCounted { return events.size() == 0; } + ur::RefCount &getRefCount() noexcept { return RefCount; } + private: ur_device_handle_t device; ur_context_handle_t context; @@ -71,4 +76,5 @@ struct ur_queue_handle_t_ : RefCounted { const bool inOrder; const bool profilingEnabled; std::mutex mutex; + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/common/ur_ref_count.hpp b/unified-runtime/source/common/ur_ref_count.hpp new file mode 100644 index 0000000000000..7815e1dab65d0 --- /dev/null +++ b/unified-runtime/source/common/ur_ref_count.hpp @@ -0,0 +1,36 @@ +/* + * + * Copyright (C) 2025 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM + * Exceptions. See LICENSE.TXT + * + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + */ +#ifndef URREFCOUNT_HPP +#define URREFCOUNT_HPP 1 + +#include +#include + +namespace ur { + +class RefCount { +public: + RefCount(uint32_t count = 1) : Count(count) {} + RefCount(const RefCount &) = delete; + RefCount &operator=(const RefCount &) = delete; + + uint32_t getCount() const noexcept { return Count.load(); } + uint32_t retain() { return ++Count; } + bool release() { return --Count == 0; } + void reset(uint32_t value = 1) { Count = value; } + +private: + std::atomic_uint32_t Count; +}; + +} // namespace ur + +#endif // URREFCOUNT_HPP