@@ -396,12 +396,17 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes
396
396
}
397
397
}
398
398
#endif
399
- if (!transposeBackendRequiresNccl (grid_desc->config .transpose_comm_backend ) &&
400
- !haloBackendRequiresNccl (grid_desc->config .halo_comm_backend )) {
401
- CHECK_NCCL (ncclCommDestroy (handle->nccl_comm ));
402
- handle->nccl_comm = nullptr ;
403
- CHECK_NCCL (ncclCommDestroy (handle->nccl_local_comm ));
404
- handle->nccl_local_comm = nullptr ;
399
+ if (transposeBackendRequiresNccl (grid_desc->config .transpose_comm_backend ) ||
400
+ haloBackendRequiresNccl (grid_desc->config .halo_comm_backend )) {
401
+ handle->n_grid_descs_using_nccl ++;
402
+ } else {
403
+ // Destroy both NCCL communicators to reclaim resources when no grid descriptor uses NCCL
404
+ if (handle->nccl_comm && handle->nccl_local_comm && handle->n_grid_descs_using_nccl == 0 ) {
405
+ CHECK_NCCL (ncclCommDestroy (handle->nccl_comm ));
406
+ handle->nccl_comm = nullptr ;
407
+ CHECK_NCCL (ncclCommDestroy (handle->nccl_local_comm ));
408
+ handle->nccl_local_comm = nullptr ;
409
+ }
405
410
}
406
411
407
412
*grid_desc_in = grid_desc;
@@ -437,6 +442,19 @@ cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDe
437
442
if (e) { CHECK_CUDA (cudaEventDestroy (e)); }
438
443
}
439
444
445
+ if (transposeBackendRequiresNccl (grid_desc->config .transpose_comm_backend ) ||
446
+ haloBackendRequiresNccl (grid_desc->config .halo_comm_backend )) {
447
+ handle->n_grid_descs_using_nccl --;
448
+
449
+ // Destroy both NCCL communicators to reclaim resources once no grid descriptor uses NCCL
450
+ if (handle->nccl_comm && handle->nccl_local_comm && handle->n_grid_descs_using_nccl == 0 ) {
451
+ CHECK_NCCL (ncclCommDestroy (handle->nccl_comm ));
452
+ handle->nccl_comm = nullptr ;
453
+ CHECK_NCCL (ncclCommDestroy (handle->nccl_local_comm ));
454
+ handle->nccl_local_comm = nullptr ;
455
+ }
456
+ }
457
+
440
458
#ifdef ENABLE_NVSHMEM
441
459
if (transposeBackendRequiresNvshmem (grid_desc->config .transpose_comm_backend ) ||
442
460
haloBackendRequiresNvshmem (grid_desc->config .halo_comm_backend )) {
@@ -447,6 +465,14 @@ cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDe
447
465
nvshmem_team_destroy (grid_desc->col_comm_info .nvshmem_team );
448
466
}
449
467
handle->n_grid_descs_using_nvshmem --;
468
+
469
+ // Finalize nvshmem to reclaim symmetric heap memory once no grid descriptor uses it
470
+ if (handle->nvshmem_initialized && handle->n_grid_descs_using_nvshmem == 0 ) {
471
+ nvshmem_finalize ();
472
+ handle->nvshmem_initialized = false ;
473
+ handle->nvshmem_allocations .clear ();
474
+ handle->nvshmem_allocation_size = 0 ;
475
+ }
450
476
}
451
477
#endif
452
478
0 commit comments