Skip to content

[UR][Offload] Global query/read/write support #19482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/offload/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS:
return ReturnValue(uint32_t{3});
case UR_DEVICE_INFO_COMPILER_AVAILABLE:
case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT:
return ReturnValue(true);
// Unimplemented features
case UR_DEVICE_INFO_PROGRAM_SET_SPECIALIZATION_CONSTANTS:
case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT:
case UR_DEVICE_INFO_USM_POOL_SUPPORT:
case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP:
case UR_DEVICE_INFO_IMAGE_SUPPORT:
Expand Down
84 changes: 54 additions & 30 deletions unified-runtime/source/adapters/offload/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead,
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

namespace {
ur_result_t doMemcpy(ur_queue_handle_t hQueue, void *DestPtr,
ol_device_handle_t DestDevice, const void *SrcPtr,
ol_device_handle_t SrcDevice, size_t size, bool blocking,
uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
// Ignore wait list for now
(void)numEventsInWaitList;
(void)phEventWaitList;
//

ol_event_handle_t EventOut = nullptr;

char *DevPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);

OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice,
DevPtr + offset, hQueue->OffloadDevice, size,
phEvent ? &EventOut : nullptr));
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DestPtr, DestDevice, SrcPtr,
SrcDevice, size, phEvent ? &EventOut : nullptr));

if (blockingRead) {
if (blocking) {
OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue));
}

Expand All @@ -124,37 +122,63 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(

return UR_RESULT_SUCCESS;
}
} // namespace

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead,
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
char *DevPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);

return doMemcpy(hQueue, pDst, Adapter->HostDevice, DevPtr + offset,
hQueue->OffloadDevice, size, blockingRead,
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

// Ignore wait list for now
(void)numEventsInWaitList;
(void)phEventWaitList;
//

ol_event_handle_t EventOut = nullptr;

char *DevPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);

OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DevPtr + offset,
hQueue->OffloadDevice, pSrc, Adapter->HostDevice,
size, phEvent ? &EventOut : nullptr));
return doMemcpy(hQueue, DevPtr + offset, hQueue->OffloadDevice, pSrc,
Adapter->HostDevice, size, blockingWrite, numEventsInWaitList,
phEventWaitList, phEvent);
}

if (blockingWrite) {
OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue));
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
bool blockingRead, size_t count, size_t offset, void *pDst,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
void *Ptr;
if (auto Err = urProgramGetGlobalVariablePointer(nullptr, hProgram, name,
nullptr, &Ptr)) {
return Err;
}

if (phEvent) {
auto *Event = new ur_event_handle_t_();
Event->OffloadEvent = EventOut;
*phEvent = Event;
return doMemcpy(hQueue, pDst, Adapter->HostDevice,
reinterpret_cast<const char *>(Ptr) + offset,
hQueue->OffloadDevice, count, blockingRead,
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
bool blockingWrite, size_t count, size_t offset, const void *pSrc,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
void *Ptr;
if (auto Err = urProgramGetGlobalVariablePointer(nullptr, hProgram, name,
nullptr, &Ptr)) {
return Err;
}

return UR_RESULT_SUCCESS;
return doMemcpy(hQueue, reinterpret_cast<char *>(Ptr) + offset,
hQueue->OffloadDevice, pSrc, Adapter->HostDevice, count,
blockingWrite, numEventsInWaitList, phEventWaitList, phEvent);
}

ur_result_t enqueueNoOp(ur_queue_handle_t hQueue, ur_event_handle_t *phEvent) {
Expand Down
101 changes: 77 additions & 24 deletions unified-runtime/source/adapters/offload/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace {
#ifdef UR_CUDA_ENABLED
ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext,
const uint8_t *Binary, size_t Length,
ur_program_handle_t *phProgram) {
ur_program_handle_t hProgram) {
uint8_t *RealBinary;
size_t RealLength;
CUlinkState State;
Expand All @@ -48,25 +48,17 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext,
fprintf(stderr, "Performed CUDA bin workaround (size = %lu)\n", RealLength);
#endif

ur_program_handle_t Program = new ur_program_handle_t_();
auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary,
RealLength, &Program->OffloadProgram);
RealLength, &hProgram->OffloadProgram);

// Program owns the linked module now
cuLinkDestroy(State);

if (Res != OL_SUCCESS) {
delete Program;
return offloadResultToUR(Res);
}

*phProgram = Program;

return UR_RESULT_SUCCESS;
return offloadResultToUR(Res);
}
#else
ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t, const uint8_t *,
size_t, ur_program_handle_t *) {
size_t, ur_program_handle_t) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
#endif
Expand All @@ -76,7 +68,8 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t, const uint8_t *,
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
ur_context_handle_t hContext, uint32_t numDevices,
ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
const ur_program_properties_t *, ur_program_handle_t *phProgram) {
const ur_program_properties_t *pProperties,
ur_program_handle_t *phProgram) {
if (numDevices > 1) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
Expand All @@ -100,24 +93,55 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
}
}

ur_program_handle_t Program = new ur_program_handle_t_{};
Program->URContext = hContext;
Program->Binary = RealBinary;
Program->BinarySizeInBytes = RealLength;

// Parse properties
if (pProperties) {
if (pProperties->count > 0 && pProperties->pMetadatas == nullptr) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
} else if (pProperties->count == 0 && pProperties->pMetadatas != nullptr) {
return UR_RESULT_ERROR_INVALID_SIZE;
}

auto Length = pProperties->count;
auto Metadata = pProperties->pMetadatas;
for (size_t i = 0; i < Length; ++i) {
const ur_program_metadata_t MetadataElement = Metadata[i];
std::string MetadataElementName{MetadataElement.pName};

auto [Prefix, Tag] = splitMetadataName(MetadataElementName);

if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) {
const char *MetadataValPtr =
reinterpret_cast<const char *>(MetadataElement.value.pData) +
sizeof(std::uint64_t);
const char *MetadataValPtrEnd =
MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t);
Program->GlobalIDMD[Prefix] =
std::string{MetadataValPtr, MetadataValPtrEnd};
}
}
}

ur_result_t Res;
ol_platform_backend_t Backend;
olGetPlatformInfo(phDevices[0]->Platform->OffloadPlatform,
OL_PLATFORM_INFO_BACKEND, sizeof(Backend), &Backend);
if (Backend == OL_PLATFORM_BACKEND_CUDA) {
return ProgramCreateCudaWorkaround(hContext, RealBinary, RealLength,
phProgram);
Res =
ProgramCreateCudaWorkaround(hContext, RealBinary, RealLength, Program);
} else {
Res = offloadResultToUR(olCreateProgram(hContext->Device->OffloadDevice,
RealBinary, RealLength,
&Program->OffloadProgram));
}

ur_program_handle_t Program = new ur_program_handle_t_{};
Program->URContext = hContext;
Program->Binary = RealBinary;
Program->BinarySizeInBytes = RealLength;
auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary,
RealLength, &Program->OffloadProgram);

if (Res != OL_SUCCESS) {
if (Res != UR_RESULT_SUCCESS) {
delete Program;
return offloadResultToUR(Res);
return Res;
}

*phProgram = Program;
Expand Down Expand Up @@ -240,3 +264,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
ur_program_handle_t, uint32_t, const ur_specialization_constant_info_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
ur_device_handle_t, ur_program_handle_t hProgram,
const char *pGlobalVariableName, size_t *pGlobalVariableSizeRet,
void **ppGlobalVariablePointerRet) {
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(pGlobalVariableName);
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
return UR_RESULT_ERROR_INVALID_VALUE;
std::string DeviceGlobalName = DeviceGlobalNameIt->second;

ol_symbol_handle_t Symbol;
auto Err = olGetSymbol(hProgram->OffloadProgram, DeviceGlobalName.c_str(),
OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Symbol);
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
return UR_RESULT_ERROR_INVALID_VALUE;
}
OL_RETURN_ON_ERR(Err);

if (pGlobalVariableSizeRet) {
OL_RETURN_ON_ERR(olGetSymbolInfo(Symbol,
OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
sizeof(size_t), pGlobalVariableSizeRet));
}
OL_RETURN_ON_ERR(olGetSymbolInfo(Symbol,
OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
sizeof(void *), ppGlobalVariablePointerRet));

return UR_RESULT_SUCCESS;
}
2 changes: 2 additions & 0 deletions unified-runtime/source/adapters/offload/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ struct ur_program_handle_t_ : RefCounted {
ur_context_handle_t URContext;
const uint8_t *Binary;
size_t BinarySizeInBytes;
// A mapping from mangled global names -> names in the binary
std::unordered_map<std::string, std::string> GlobalIDMD;
};
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable(
pDdiTable->pfnCreateWithNativeHandle = urProgramCreateWithNativeHandle;
pDdiTable->pfnGetBuildInfo = nullptr;
pDdiTable->pfnGetFunctionPointer = nullptr;
pDdiTable->pfnGetGlobalVariablePointer = nullptr;
pDdiTable->pfnGetGlobalVariablePointer = urProgramGetGlobalVariablePointer;
pDdiTable->pfnGetInfo = urProgramGetInfo;
pDdiTable->pfnGetNativeHandle = urProgramGetNativeHandle;
pDdiTable->pfnLink = nullptr;
Expand Down Expand Up @@ -168,8 +168,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
if (UR_RESULT_SUCCESS != result) {
return result;
}
pDdiTable->pfnDeviceGlobalVariableRead = nullptr;
pDdiTable->pfnDeviceGlobalVariableWrite = nullptr;
pDdiTable->pfnDeviceGlobalVariableRead = urEnqueueDeviceGlobalVariableRead;
pDdiTable->pfnDeviceGlobalVariableWrite = urEnqueueDeviceGlobalVariableWrite;
pDdiTable->pfnEventsWait = nullptr;
pDdiTable->pfnEventsWaitWithBarrier = nullptr;
pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
Expand Down