diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
index e5332686fcffb..361a781e8f9b6 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
@@ -31,6 +31,7 @@ DLWRAP(cuDeviceGet, 2)
 DLWRAP(cuDeviceGetAttribute, 3)
 DLWRAP(cuDeviceGetCount, 1)
 DLWRAP(cuFuncGetAttribute, 3)
+DLWRAP(cuFuncSetAttribute, 3)
 
 // Device info
 DLWRAP(cuDeviceGetName, 3)
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
index 1c5b421768894..b6c022c8e7e8b 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
@@ -258,6 +258,7 @@ typedef enum CUdevice_attribute_enum {
 
 typedef enum CUfunction_attribute_enum {
   CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 0,
+  CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8,
 } CUfunction_attribute;
 
 typedef enum CUctx_flags_enum {
@@ -295,6 +296,7 @@ CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
 CUresult cuDeviceGetCount(int *);
 CUresult cuFuncGetAttribute(int *, CUfunction_attribute, CUfunction);
+CUresult cuFuncSetAttribute(CUfunction, CUfunction_attribute, int);
 
 // Device info
 CUresult cuDeviceGetName(char *, int, CUdevice);
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index f1164074f9ea9..9a9237a74b0ef 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -160,6 +160,10 @@ struct CUDAKernelTy : public GenericKernelTy {
 private:
   /// The CUDA kernel function to execute.
   CUfunction Func;
+  /// The largest dynamic shared memory size per thread group, in bytes, that
+  /// launchImpl has configured for this kernel so far. Starts at 48 KB (49152
+  /// bytes) and is raised whenever a launch requests more than this value.
+  mutable uint32_t MaxDynCGroupMemLimit = 49152;
 };
 
 /// Class wrapping a CUDA stream reference. These are the objects handled by the
@@ -1302,6 +1306,20 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   if (GenericDevice.getRPCServer())
     GenericDevice.Plugin.getRPCServer().Thread->notify();
 
+  // Raise the kernel's dynamic shared memory limit only when this launch needs
+  // more than what is already configured; an equal request needs no update.
+  if (MaxDynCGroupMem > MaxDynCGroupMemLimit) {
+    CUresult AttrResult = cuFuncSetAttribute(
+        Func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, MaxDynCGroupMem);
+    // Propagate a failure instead of dropping it, and leave the cached limit
+    // untouched so a later launch retries the update.
+    if (Error Err = Plugin::check(
+            AttrResult,
+            "Error in cuFuncSetAttribute while setting the memory limits: %s"))
+      return Err;
+    MaxDynCGroupMemLimit = MaxDynCGroupMem;
+  }
+
   CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
                                 NumThreads[0], NumThreads[1], NumThreads[2],
                                 MaxDynCGroupMem, Stream, nullptr, Config);