@@ -187,6 +187,13 @@ struct FreeInfo
     bool is_async = false;
 };
 
+// Information used when deferring module unloading.
+struct ModuleInfo
+{
+    void* context = NULL;
+    void* module = NULL;
+};
+
 static std::unordered_map<CUfunction, std::string> g_kernel_names;
 
 // cached info for all devices, indexed by ordinal
@@ -214,6 +221,9 @@ static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
 // Call free_deferred_allocs() to release.
 static std::vector<FreeInfo> g_deferred_free_list;
 
+// Modules that cannot be unloaded immediately get queued here.
+// Call unload_deferred_modules() to release.
+static std::vector<ModuleInfo> g_deferred_module_list;
 
 void cuda_set_context_restore_policy(bool always_restore)
 {
@@ -410,6 +420,31 @@ static int free_deferred_allocs(void* context = NULL)
     return num_freed_allocs;
 }
 
+static int unload_deferred_modules(void* context = NULL)
+{
+    if (g_deferred_module_list.empty() || !g_captures.empty())
+        return 0;
+
+    int num_unloaded_modules = 0;
+    for (auto it = g_deferred_module_list.begin(); it != g_deferred_module_list.end(); /*noop*/)
+    {
+        // free the module if it matches the given context or if the context is unspecified
+        const ModuleInfo& module_info = *it;
+        if (module_info.context == context || !context)
+        {
+            cuda_unload_module(module_info.context, module_info.module);
+            ++num_unloaded_modules;
+            it = g_deferred_module_list.erase(it);
+        }
+        else
+        {
+            ++it;
+        }
+    }
+
+    return num_unloaded_modules;
+}
+
 static void CUDART_CB on_graph_destroy(void* user_data)
 {
     if (!user_data)
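Note: unload_deferred_modules() drains the queue with the iterator-erase pattern, releasing only entries whose context matches (or every entry when no context is given), and skips the flush entirely while any capture is still active. Below is a minimal standalone sketch of that pattern; the DeferredItem type, g_deferred list, and flush_deferred() helper are illustrative stand-ins, not part of this source.

// Sketch only: context-filtered deferred release with erase-while-iterating.
#include <cstdio>
#include <vector>

struct DeferredItem
{
    void* context = nullptr;  // context the resource belongs to
    int id = 0;               // stand-in for the resource handle
};

static std::vector<DeferredItem> g_deferred;

// Release every queued item matching `context`, or all items when context is null.
static int flush_deferred(void* context = nullptr)
{
    int num_released = 0;
    for (auto it = g_deferred.begin(); it != g_deferred.end(); /* advanced in body */)
    {
        if (it->context == context || !context)
        {
            std::printf("releasing item %d\n", it->id);
            ++num_released;
            it = g_deferred.erase(it);  // erase() returns the next valid iterator
        }
        else
        {
            ++it;  // keep entries belonging to other contexts
        }
    }
    return num_released;
}

int main()
{
    int ctx_a, ctx_b;  // dummy context tags
    g_deferred.push_back({&ctx_a, 1});
    g_deferred.push_back({&ctx_b, 2});
    g_deferred.push_back({&ctx_a, 3});

    flush_deferred(&ctx_a);   // releases items 1 and 3
    flush_deferred(nullptr);  // releases whatever remains (item 2)
    return 0;
}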
@@ -1920,6 +1955,8 @@ void cuda_context_synchronize(void* context)
         check_cu(cuCtxSynchronize_f());
     }
 
+    unload_deferred_modules(context);
+
     // check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
 }
 
@@ -2542,7 +2579,10 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
 
     // process deferred free list if no more captures are ongoing
     if (g_captures.empty())
+    {
         free_deferred_allocs();
+        unload_deferred_modules();
+    }
 
     if (graph_ret)
         *graph_ret = graph_exec;
@@ -3104,9 +3144,20 @@ void* cuda_load_module(void* context, const char* path)
 
 void cuda_unload_module(void* context, void* module)
 {
-    ContextGuard guard(context);
-
-    check_cu(cuModuleUnload_f((CUmodule)module));
+    // ensure there are no graph captures in progress
+    if (g_captures.empty())
+    {
+        ContextGuard guard(context);
+        check_cu(cuModuleUnload_f((CUmodule)module));
+    }
+    else
+    {
+        // defer until graph capture completes
+        ModuleInfo module_info;
+        module_info.context = context ? context : get_current_context();
+        module_info.module = module;
+        g_deferred_module_list.push_back(module_info);
+    }
 }
 
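Note: with this change cuda_unload_module() only calls cuModuleUnload immediately when no graph capture is in progress; otherwise it queues a ModuleInfo entry that is flushed later by unload_deferred_modules() (invoked from cuda_context_synchronize() and cuda_graph_end_capture() above). Below is a minimal standalone sketch of that act-now-or-defer dispatch; Resource, release_resource(), and end_capture() are illustrative stand-ins, not part of this source.

// Sketch only: defer resource destruction while a capture is active, flush afterwards.
#include <cstdio>
#include <vector>

struct Resource { int id = 0; };

static bool g_capture_active = false;           // stand-in for !g_captures.empty()
static std::vector<Resource> g_deferred_releases;

static void destroy_now(const Resource& r)
{
    std::printf("destroying resource %d immediately\n", r.id);
}

// Release a resource, deferring destruction while a capture is active so that
// work referenced by the capture is not pulled out from under it.
static void release_resource(const Resource& r)
{
    if (!g_capture_active)
        destroy_now(r);
    else
        g_deferred_releases.push_back(r);       // flushed once capture ends
}

// Called when the capture ends: destroy everything queued in the meantime.
static void end_capture()
{
    g_capture_active = false;
    for (const Resource& r : g_deferred_releases)
        destroy_now(r);
    g_deferred_releases.clear();
}

int main()
{
    release_resource({1});   // no capture active: destroyed immediately

    g_capture_active = true;
    release_resource({2});   // capture active: queued
    release_resource({3});   // capture active: queued

    end_capture();           // resources 2 and 3 destroyed here
    return 0;
}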