Commit edf67c0

Merge branch 'lwawrzyniak/deferred-unload' into 'main'

Fix graph capture errors caused by module unloading

Closes GH-401
See merge request omniverse/warp!932

2 parents: 4b68286 + e79ed07

File tree: 4 files changed, +99 -3 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -25,6 +25,7 @@
 - Fix for occasional failure to update .meta files into Warp kernel cache on Windows
 - Mark kernel arrays as written to when passed to `wp.atomic_add()` or `wp.atomic_sub()`
 - Fix the OpenGL renderer not being able to run without CUDA ([GH-344](https://github.com/NVIDIA/warp/issues/344)).
+- Fix errors during graph capture caused by module unloading ([GH-401](https://github.com/NVIDIA/warp/issues/401)).
 
 ## [1.5.0] - 2024-12-02
```

warp/native/warp.cu

Lines changed: 54 additions & 3 deletions
```diff
@@ -187,6 +187,13 @@ struct FreeInfo
     bool is_async = false;
 };
 
+// Information used when deferring module unloading.
+struct ModuleInfo
+{
+    void* context = NULL;
+    void* module = NULL;
+};
+
 static std::unordered_map<CUfunction, std::string> g_kernel_names;
 
 // cached info for all devices, indexed by ordinal
@@ -214,6 +221,9 @@ static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
 // Call free_deferred_allocs() to release.
 static std::vector<FreeInfo> g_deferred_free_list;
 
+// Modules that cannot be unloaded immediately get queued here.
+// Call unload_deferred_modules() to release.
+static std::vector<ModuleInfo> g_deferred_module_list;
 
 void cuda_set_context_restore_policy(bool always_restore)
 {
@@ -410,6 +420,31 @@ static int free_deferred_allocs(void* context = NULL)
     return num_freed_allocs;
 }
 
+static int unload_deferred_modules(void* context = NULL)
+{
+    if (g_deferred_module_list.empty() || !g_captures.empty())
+        return 0;
+
+    int num_unloaded_modules = 0;
+    for (auto it = g_deferred_module_list.begin(); it != g_deferred_module_list.end(); /*noop*/)
+    {
+        // free the module if it matches the given context or if the context is unspecified
+        const ModuleInfo& module_info = *it;
+        if (module_info.context == context || !context)
+        {
+            cuda_unload_module(module_info.context, module_info.module);
+            ++num_unloaded_modules;
+            it = g_deferred_module_list.erase(it);
+        }
+        else
+        {
+            ++it;
+        }
+    }
+
+    return num_unloaded_modules;
+}
+
 static void CUDART_CB on_graph_destroy(void* user_data)
 {
     if (!user_data)
@@ -1920,6 +1955,8 @@ void cuda_context_synchronize(void* context)
         check_cu(cuCtxSynchronize_f());
     }
 
+    unload_deferred_modules(context);
+
     // check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
 }
 
@@ -2542,7 +2579,10 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
 
     // process deferred free list if no more captures are ongoing
     if (g_captures.empty())
+    {
         free_deferred_allocs();
+        unload_deferred_modules();
+    }
 
     if (graph_ret)
         *graph_ret = graph_exec;
@@ -3104,9 +3144,20 @@ void* cuda_load_module(void* context, const char* path)
 
 void cuda_unload_module(void* context, void* module)
 {
-    ContextGuard guard(context);
-
-    check_cu(cuModuleUnload_f((CUmodule)module));
+    // ensure there are no graph captures in progress
+    if (g_captures.empty())
+    {
+        ContextGuard guard(context);
+        check_cu(cuModuleUnload_f((CUmodule)module));
+    }
+    else
+    {
+        // defer until graph capture completes
+        ModuleInfo module_info;
+        module_info.context = context ? context : get_current_context();
+        module_info.module = module;
+        g_deferred_module_list.push_back(module_info);
+    }
 }
```

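Taken together, the warp.cu changes implement a defer-until-safe pattern: `cuda_unload_module()` queues the request while any graph capture is active, and the queue is flushed from `cuda_context_synchronize()` and `cuda_graph_end_capture()` once no captures remain. The sketch below transliterates that bookkeeping into plain Python to make the flow easy to follow; it is not Warp's implementation, and `do_unload()` is a hypothetical stand-in for the `ContextGuard` + `cuModuleUnload` calls above.

```python
# Minimal Python sketch of the deferral bookkeeping in warp.cu (not Warp's
# actual implementation; do_unload() is a hypothetical stand-in for the
# ContextGuard + cuModuleUnload driver calls).

deferred_modules = []    # queued (context, module) pairs, like g_deferred_module_list
active_captures = set()  # ongoing graph captures, like g_captures


def do_unload(context, module):
    # Stand-in for the real driver-level unload.
    print(f"unloading {module} on {context}")


def unload_module(context, module):
    # Mirrors cuda_unload_module(): unload immediately only when no capture
    # is in progress, otherwise queue the request for later.
    if active_captures:
        deferred_modules.append((context, module))
    else:
        do_unload(context, module)


def unload_deferred_modules(context=None):
    # Mirrors unload_deferred_modules(): flush queued unloads once captures
    # are done. A context of None matches every entry, like the NULL default
    # in the C++ version.
    if active_captures:
        return 0
    remaining, unloaded = [], 0
    for ctx, module in deferred_modules:
        if context is None or ctx == context:
            do_unload(ctx, module)
            unloaded += 1
        else:
            remaining.append((ctx, module))
    deferred_modules[:] = remaining
    return unloaded
```

The key invariant is that a CUDA module backing kernels recorded into an open capture stays alive until the capture completes; both flush points re-check the set of active captures before touching the queue.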
warp/tests/aux_test_module_unload.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Dummy module used in test_reload.py"""
+
+import warp as wp
+
+
+@wp.kernel
+def k():
+    pass
```
warp/tests/test_reload.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -241,6 +241,32 @@ def foo(a: wp.array(dtype=int)):
     test.assertEqual(a.numpy()[0], 42)
 
 
+def test_module_unload_during_graph_capture(test, device):
+    @wp.kernel
+    def foo(a: wp.array(dtype=int)):
+        a[0] = 42
+
+    # preload module before graph capture
+    wp.load_module(device=device)
+
+    # load another module to test unloading during graph capture
+    other_module = wp.get_module("warp.tests.aux_test_module_unload")
+    other_module.load(device)
+
+    with wp.ScopedDevice(device):
+        a = wp.zeros(1, dtype=int)
+
+        with wp.ScopedCapture(force_module_load=False) as capture:
+            wp.launch(foo, dim=1, inputs=[a])
+
+            # unloading a module during graph capture should be fine (deferred until capture completes)
+            other_module.unload()
+
+        wp.capture_launch(capture.graph)
+
+        test.assertEqual(a.numpy()[0], 42)
+
+
 devices = get_test_devices()
 cuda_devices = get_cuda_test_devices()
@@ -258,6 +284,9 @@ class TestReload(unittest.TestCase):
 add_function_test(
     TestReload, "test_graph_launch_after_module_reload", test_graph_launch_after_module_reload, devices=cuda_devices
 )
+add_function_test(
+    TestReload, "test_module_unload_during_graph_capture", test_module_unload_during_graph_capture, devices=cuda_devices
+)
 
 
 if __name__ == "__main__":
```
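From the user's side, the new test doubles as a recipe: a module unload issued inside `wp.ScopedCapture` is now deferred instead of corrupting the capture. Below is a hedged usage sketch built only from calls that appear in the test above; the device string and `"my_pkg.my_module"` are placeholders.

```python
import warp as wp


@wp.kernel
def write_answer(a: wp.array(dtype=int)):
    a[0] = 42


with wp.ScopedDevice("cuda:0"):
    # preload so the capture does not trigger module loading
    wp.load_module(device="cuda:0")
    a = wp.zeros(1, dtype=int)

    with wp.ScopedCapture(force_module_load=False) as capture:
        wp.launch(write_answer, dim=1, inputs=[a])
        # Before this commit, an unload here could break the capture;
        # it is now deferred until the capture completes
        # ("my_pkg.my_module" is a placeholder module name).
        wp.get_module("my_pkg.my_module").unload()

    wp.capture_launch(capture.graph)
    assert a.numpy()[0] == 42
```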
