@@ -159,12 +159,96 @@ static int accelerator_cuda_get_device_id(CUcontext mem_ctx) {
159159 return dev_id ;
160160}
161161
162+ static int accelerator_cuda_check_mpool (CUdeviceptr dbuf , CUmemorytype * mem_type ,
163+ int * dev_id )
164+ {
165+ #if OPAL_CUDA_VMM_SUPPORT
166+ static int device_count = -1 ;
167+ static int mpool_supported = -1 ;
168+ CUresult result ;
169+ CUmemoryPool mpool ;
170+ CUmemAccess_flags flags ;
171+ CUmemLocation location ;
172+
173+ if (mpool_supported <= 0 ) {
174+ if (mpool_supported == -1 ) {
175+ if (device_count == -1 ) {
176+ result = cuDeviceGetCount (& device_count );
177+ if (result != CUDA_SUCCESS || (0 == device_count )) {
178+ mpool_supported = 0 ; /* never check again */
179+ device_count = 0 ;
180+ return 0 ;
181+ }
182+ }
183+
184+ /* assume uniformity of devices */
185+ result = cuDeviceGetAttribute (& mpool_supported ,
186+ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED , 0 );
187+ if (result != CUDA_SUCCESS ) {
188+ mpool_supported = 0 ;
189+ }
190+ }
191+ if (0 == mpool_supported ) {
192+ return 0 ;
193+ }
194+ }
195+
196+ result = cuPointerGetAttribute (& mpool , CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ,
197+ dbuf );
198+ if (CUDA_SUCCESS != result ) {
199+ return 0 ;
200+ }
201+
202+ /* check if device has access */
203+ for (int i = 0 ; i < device_count ; i ++ ) {
204+ location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
205+ location .id = i ;
206+ result = cuMemPoolGetAccess (& flags , mpool , & location );
207+ if ((CUDA_SUCCESS == result ) &&
208+ (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
209+ * mem_type = CU_MEMORYTYPE_DEVICE ;
210+ * dev_id = i ;
211+ return 1 ;
212+ }
213+ }
214+
215+ /* host must have access as device access possibility is exhausted */
216+ * mem_type = CU_MEMORYTYPE_HOST ;
217+ * dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
218+ return 0 ;
219+ #endif
220+
221+ return 0 ;
222+ }
223+
224+ static int accelerator_cuda_get_primary_context (CUdevice dev_id , CUcontext * pctx )
225+ {
226+ CUresult result ;
227+ unsigned int flags ;
228+ int active ;
229+
230+ result = cuDevicePrimaryCtxGetState (dev_id , & flags , & active );
231+ if (CUDA_SUCCESS != result ) {
232+ return OPAL_ERROR ;
233+ }
234+
235+ if (active ) {
236+ result = cuDevicePrimaryCtxRetain (pctx , dev_id );
237+ return OPAL_SUCCESS ;
238+ }
239+
240+ return OPAL_ERROR ;
241+ }
242+
162243static int accelerator_cuda_check_addr (const void * addr , int * dev_id , uint64_t * flags )
163244{
164245 CUresult result ;
165246 int is_vmm = 0 ;
247+ int is_mpool_ptr = 0 ;
166248 int vmm_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
249+ int mpool_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
167250 CUmemorytype vmm_mem_type = 0 ;
251+ CUmemorytype mpool_mem_type = 0 ;
168252 CUmemorytype mem_type = 0 ;
169253 CUdeviceptr dbuf = (CUdeviceptr ) addr ;
170254 CUcontext ctx = NULL , mem_ctx = NULL ;
@@ -177,6 +261,7 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
177261 * flags = 0 ;
178262
179263 is_vmm = accelerator_cuda_check_vmm (dbuf , & vmm_mem_type , & vmm_dev_id );
264+ is_mpool_ptr = accelerator_cuda_check_mpool (dbuf , & mpool_mem_type , & mpool_dev_id );
180265
181266#if OPAL_CUDA_GET_ATTRIBUTES
182267 uint32_t is_managed = 0 ;
@@ -210,6 +295,9 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
210295 if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
211296 mem_type = CU_MEMORYTYPE_DEVICE ;
212297 * dev_id = vmm_dev_id ;
298+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
299+ mem_type = CU_MEMORYTYPE_DEVICE ;
300+ * dev_id = mpool_dev_id ;
213301 } else {
214302 /* Host memory, nothing to do here */
215303 return 0 ;
@@ -220,6 +308,8 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
220308 } else {
221309 if (is_vmm ) {
222310 * dev_id = vmm_dev_id ;
311+ } else if (is_mpool_ptr ) {
312+ * dev_id = mpool_dev_id ;
223313 } else {
224314 /* query the device from the context */
225315 * dev_id = accelerator_cuda_get_device_id (mem_ctx );
@@ -238,13 +328,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
238328 if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
239329 mem_type = CU_MEMORYTYPE_DEVICE ;
240330 * dev_id = vmm_dev_id ;
331+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
332+ mem_type = CU_MEMORYTYPE_DEVICE ;
333+ * dev_id = mpool_dev_id ;
241334 } else {
242335 /* Host memory, nothing to do here */
243336 return 0 ;
244337 }
245338 } else {
246339 if (is_vmm ) {
247340 * dev_id = vmm_dev_id ;
341+ } else if (is_mpool_ptr ) {
342+ * dev_id = mpool_dev_id ;
248343 } else {
249344 result = cuPointerGetAttribute (& mem_ctx ,
250345 CU_POINTER_ATTRIBUTE_CONTEXT , dbuf );
@@ -278,14 +373,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
278373 return OPAL_ERROR ;
279374 }
280375#endif /* OPAL_CUDA_GET_ATTRIBUTES */
281- if (is_vmm ) {
282- /* This function is expected to set context if pointer is device
283- * accessible but VMM allocations have NULL context associated
284- * which cannot be set against the calling thread */
285- opal_output (0 ,
286- "CUDA: unable to set context with the given pointer"
287- "ptr=%p aborting..." , addr );
288- return OPAL_ERROR ;
376+ if (is_vmm || is_mpool_ptr ) {
377+ if (OPAL_SUCCESS ==
378+ accelerator_cuda_get_primary_context (
379+ is_vmm ? vmm_dev_id : mpool_dev_id , & mem_ctx )) {
380+ /* As VMM/mempool allocations have no context associated
381+ * with them, check if device primary context can be set */
382+ } else {
383+ opal_output (0 ,
384+ "CUDA: unable to set ctx with the given pointer"
385+ "ptr=%p aborting..." , addr );
386+ return OPAL_ERROR ;
387+ }
289388 }
290389
291390 result = cuCtxSetCurrent (mem_ctx );
0 commit comments